YOLOv5目标检测代码精读

目录

YOLOv5目标检测代码精读

本文深入分析YOLOv5训练流程与数据增强机制,帮助个人梳理总结Yolov5这一目标检测模型的内部实现细节。


1. train.py 文件解析

1.1 Import 部分

import argparse
import math
import os
import random
import subprocess
import sys
import time
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path

try:
    import comet_ml  # must be imported before torch (if installed)
except ImportError:
    comet_ml = None

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import yaml
from torch.optim import lr_scheduler
from tqdm import tqdm

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

import val as validate  # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.dataloaders import create_dataloader
from utils.downloads import attempt_download, is_url
from utils.general import (
    LOGGER,
    TQDM_BAR_FORMAT,
    check_amp,
    check_dataset,
    check_file,
    check_git_info,
    check_git_status,
    check_img_size,
    check_requirements,
    check_suffix,
    check_yaml,
    colorstr,
    get_latest_run,
    increment_path,
    init_seeds,
    intersect_dicts,
    labels_to_class_weights,
    labels_to_image_weights,
    methods,
    one_cycle,
    print_args,
    print_mutation,
    strip_optimizer,
    yaml_save,
)
from utils.loggers import LOGGERS, Loggers
from utils.loggers.comet.comet_utils import check_comet_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve
from utils.torch_utils import (
    EarlyStopping,
    ModelEMA,
    de_parallel,
    select_device,
    smart_DDP,
    smart_optimizer,
    smart_resume,
    torch_distributed_zero_first,
)

LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))
RANK = int(os.getenv("RANK", -1))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
GIT_INFO = check_git_info()

1.2 Train() 函数详解

Train()函数是YOLOv5训练的核心函数,负责整个训练流程的管理:

def train(hyp, opt, device, callbacks):
    """
    Train a YOLOv5 model on a custom dataset using specified hyperparameters, options, and device, managing datasets,
    model architecture, loss computation, and optimizer steps.

    Args:
        hyp (str | dict): Path to the hyperparameters YAML file or a dictionary of hyperparameters.
        opt (argparse.Namespace): Parsed command-line arguments containing training options.
        device (torch.device): Device on which training occurs, e.g., 'cuda' or 'cpu'.
        callbacks (Callbacks): Callback functions for various training events.

    Returns:
        None

    Models and datasets download automatically from the latest YOLOv5 release.

    Example:
        Single-GPU training:
        ```bash
        $ python train.py --data coco128.yaml --weights yolov5s.pt --img 640  # from pretrained (recommended)
        $ python train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640  # from scratch
        ```

        Multi-GPU DDP training:
        ```bash
        $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 train.py --data coco128.yaml --weights
        yolov5s.pt --img 640 --device 0,1,2,3
        ```

        For more usage details, refer to:
        - Models: https://github.com/ultralytics/yolov5/tree/master/models
        - Datasets: https://github.com/ultralytics/yolov5/tree/master/data
        - Tutorial: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
    """

函数执行的主要步骤包括:

  1. 参数解析与初始化:处理输入参数,设置保存目录和训练配置
  2. 加载超参数:从YAML文件加载或使用传入的超参数字典
  3. 配置日志系统:初始化日志记录器和回调函数
  4. 加载数据集:验证数据集格式并获取训练和验证路径
  5. 模型创建或加载:根据配置创建新模型或加载预训练权重
  6. 优化器配置:设置优化器、学习率调度器和EMA(指数移动平均)
  7. 数据加载器创建:创建训练和验证数据加载器
  8. 开始训练循环:执行多轮训练,每轮包括:
    • 训练阶段(前向传播、损失计算、反向传播、参数更新)
    • 可选的验证阶段(计算mAP等指标)
    • 模型保存和早停检查
  9. 训练结束:保存最终模型,进行最后验证,释放资源

特别值得注意的是训练数据加载器的创建部分:

# Trainloader
train_loader, dataset = create_dataloader(
    train_path,
    imgsz,
    batch_size // WORLD_SIZE,
    gs,
    single_cls,
    hyp=hyp,
    augment=True, # 数据增强在训练中默认开启
    cache=None if opt.cache == "val" else opt.cache,
    rect=opt.rect,
    rank=LOCAL_RANK,
    workers=workers,
    image_weights=opt.image_weights,
    quad=opt.quad,
    prefix=colorstr("train: "),
    shuffle=True,
    seed=opt.seed,
)

对比的验证数据加载器中没有启用数据增强:

# Process 0
# 验证数据加载器中无数据增强:
# 相比之下,验证数据加载器中没有启用数据增强,这是合理的,因为验证应该在不增强的原始数据上进行:
# 注意这里没有指定augment参数,它会使用create_dataloader函数的默认值False。
if RANK in {-1, 0}:
    val_loader = create_dataloader(
        val_path,
        imgsz,
        batch_size // WORLD_SIZE * 2,
        gs,
        single_cls,
        hyp=hyp,
        cache=None if noval else opt.cache,
        rect=True,
        rank=-1,
        workers=workers * 2,
        pad=0.5,
        prefix=colorstr("val: "),
    )[0]

1.3 训练中的数据增强

YOLOv5的训练默认开启了数据增强。根据代码分析,我们可以确认以下几点:

  1. 训练数据加载器中的数据增强设置: 在train.py文件的create_dataloader调用中,augment参数被明确设置为True

    train_loader, dataset = create_dataloader(
        train_path,
        imgsz,
        batch_size // WORLD_SIZE,
        gs,
        single_cls,
        hyp=hyp,
        augment=True,  # 数据增强在训练中默认开启
        cache=None if opt.cache == "val" else opt.cache,
        rect=opt.rect,
        rank=LOCAL_RANK,
        workers=workers,
        image_weights=opt.image_weights,
        quad=opt.quad,
        prefix=colorstr("train: "),
        shuffle=True,
        seed=opt.seed,
    )
    
  2. 验证数据加载器中无数据增强: 相比之下,验证数据加载器中没有启用数据增强,这是合理的,因为验证应该在不增强的原始数据上进行。

LoadImagesAndLabels类的__getitem__方法中,当augment=True时会应用多种数据增强技术:

  1. Mosaic增强:将4张不同图像合成一张,增加多尺度训练和小物体检测能力

    if mosaic := self.mosaic and random.random() < hyp["mosaic"]:
        img, labels = self.load_mosaic(index)
    
  2. MixUp增强:将两张图像按一定比例混合,增加训练数据的复杂性

    if random.random() < hyp["mixup"]:
        img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices)))
    
  3. 随机透视变换:包括旋转、平移、缩放、剪切等几何变换

    img, labels = random_perspective(
        img,
        labels,
        degrees=hyp["degrees"],
        translate=hyp["translate"],
        scale=hyp["scale"],
        shear=hyp["shear"],
        perspective=hyp["perspective"],
    )
    
  4. Albumentations库增强:一个强大的图像增强库提供的额外增强

    img, labels = self.albumentations(img, labels)
    
  5. HSV颜色空间增强:调整色调、饱和度和亮度

    augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])
    
  6. 随机翻转:上下和左右翻转

    if random.random() < hyp["flipud"]:
        img = np.flipud(img)
    
    if random.random() < hyp["fliplr"]:
        img = np.fliplr(img)
    
  7. Cutout(注释掉的):随机遮挡图像中的某些区域以增强模型的鲁棒性

1.4 增强参数控制

数据增强的具体参数通过超参数文件(hyp.yaml)控制,包括:

参数描述作用
mosaicMosaic增强的应用概率控制是否应用Mosaic增强
mixupMixUp增强的应用概率控制是否应用MixUp增强
hsv_hHSV色调调整强度控制色调变化范围
hsv_sHSV饱和度调整强度控制饱和度变化范围
hsv_vHSV亮度调整强度控制亮度变化范围
degrees旋转角度范围控制随机旋转的最大角度
translate平移范围控制随机平移的最大比例
scale缩放范围控制随机缩放的最大比例
shear剪切范围控制随机剪切的最大角度
perspective透视变换强度控制透视变换的强度
flipud上下翻转概率控制上下翻转的概率
fliplr左右翻转概率控制左右翻转的概率

1.5 结论

YOLOv5在训练过程中默认开启了丰富的数据增强策略,这是模型能够实现高检测性能的关键因素之一。这些增强包括图像融合(Mosaic和MixUp)、几何变换、颜色调整和随机翻转等。这些技术共同作用,大大增加了训练数据的多样性,帮助模型学习更加鲁棒的特征,提高了对不同环境和条件下目标的检测能力。

当使用YOLOv5训练自己的数据集时,无需手动开启数据增强,它已经默认启用。如果需要调整增强的强度,可以修改超参数文件中的相关参数。


2. YOLOv5数据加载与增强流程

整个数据加载和增强过程涉及多个函数和类的调用关系。下面详细解释这个流程:

2.1 调用关系

  1. 首先在train.py中调用create_dataloader函数:

    train_loader, dataset = create_dataloader(
        train_path,
        imgsz,
        batch_size // WORLD_SIZE,
        gs,
        single_cls,
        hyp=hyp,
        augment=True,  # 这里设置了augment=True
        cache=None if opt.cache == "val" else opt.cache,
        rect=opt.rect,
        rank=LOCAL_RANK,
        workers=workers,
        image_weights=opt.image_weights,
        quad=opt.quad,
        prefix=colorstr("train: "),
        shuffle=True,
        seed=opt.seed,
    )
    
  2. create_dataloader函数内部创建LoadImagesAndLabels数据集实例:

    dataset = LoadImagesAndLabels(
        path,
        imgsz,
        batch_size,
        augment=augment,  # 这里augment参数被传递给LoadImagesAndLabels
        hyp=hyp,
        # 其他参数...
    )
    
  3. create_dataloader函数最后返回一个PyTorchDataLoader和数据集:

    return loader(
        dataset,
        batch_size=batch_size,
        # 其他参数...
    ), dataset
    

2.2 数据增强的具体实现

数据增强发生在LoadImagesAndLabels类的__getitem__方法中,当训练过程需要一个批次数据时,这个方法会被调用:

  1. augment=True时,LoadImagesAndLabels类在初始化时会设置:

    self.augment = augment
    self.mosaic = self.augment and not self.rect  # 只有augment=True时才启用mosaic
    self.albumentations = Albumentations(size=img_size) if augment else None
    
  2. __getitem__方法中,如果self.augment=True,则会应用各种增强:

    if self.augment:
        # 随机透视变换
        img, labels = random_perspective(...)
    
        # Albumentations库增强
        img, labels = self.albumentations(img, labels)
    
        # HSV颜色空间增强
        augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])
    
        # 随机翻转
        if random.random() < hyp["flipud"]:
            img = np.flipud(img)
    
        if random.random() < hyp["fliplr"]:
            img = np.fliplr(img)
    

2.3 完整的调用链

实际的调用链是这样的:

  1. train.py → 调用create_dataloader函数
  2. create_dataloader → 创建LoadImagesAndLabels类实例
  3. create_dataloader → 使用上面的数据集创建并返回DataLoader
  4. 当训练循环执行时,DataLoader → 调用LoadImagesAndLabels.__getitem__
  5. LoadImagesAndLabels.__getitem__ → 根据augment=True应用各种数据增强
graph TD
    A[train.py] -->|调用| B[create_dataloader]
    B -->|创建| C[LoadImagesAndLabels]
    B -->|返回| D[DataLoader]
    D -->|训练循环请求数据| E[__getitem__]
    E -->|应用| F[数据增强]

2.4 结论

train.py中设置的augment=True参数最终会被传递到LoadImagesAndLabels类,并在该类的__getitem__方法中触发各种数据增强操作。这是一个典型的PyTorch数据加载流程:先定义数据集类(处理单个样本的加载和增强),然后用DataLoader包装它(处理批次、多线程等)。

YOLOv5采用这种设计使得:

  1. 代码结构清晰(数据加载和模型训练分离)
  2. 数据处理高效(多线程预加载)
  3. 增强操作灵活(可以根据需要开启或关闭)

这就是为什么在train.py中设置augment=True后,系统能够自动应用复杂的数据增强策略。


3. YOLOv5的数据加载机制详解

3.1 数据加载器的创建流程

在YOLOv5的训练流程中:

  1. create_dataloader函数创建数据加载器

    • 这个函数首先创建一个LoadImagesAndLabels类的实例作为数据集
    • 然后将这个数据集包装在PyTorch的DataLoaderInfiniteDataLoader
    • 最后返回这个数据加载器和数据集
  2. LoadImagesAndLabels类充当数据集

    • 这个类继承自PyTorch的Dataset
    • 它负责管理数据的加载、预处理和增强
    • 它定义了如何获取单个数据样本的逻辑

3.2 LoadImagesAndLabels类的主要参数

这个类包含许多重要参数,以下是主要的几个:

3.2.1 基本路径和图像设置

  • path:数据集路径(可以是目录或文件列表)
  • img_size:图像大小(默认640像素)
  • batch_size:批次大小

3.2.2 增强相关参数

  • augment:是否启用数据增强
  • hyp:超参数字典,包含各种增强的概率和强度
  • mosaic:是否使用Mosaic增强(在augment=Truerect=False时自动启用)
  • albumentations:是否使用Albumentations库的增强

3.2.3 批次和处理相关参数

  • rect:是否使用矩形训练(一个批次中使用相似宽高比的图像)
  • stride:模型的最大下采样率,用于确保图像尺寸是步长的倍数
  • pad:边界填充大小

3.2.4 缓存和性能优化参数

  • cache_images:是否缓存图像以加速训练(可以是"ram"或"disk")
  • workers:数据加载的工作线程数

3.2.5 数据集特性参数

  • single_cls:是否将所有类别视为一个类别
  • image_weights:是否使用图像权重(基于类频率)

3.3 createdataloader函数的主要功能

这个函数完成几个关键任务:

def create_dataloader(
    path,
    imgsz,
    batch_size,
    stride,
    single_cls=False,
    hyp=None,
    augment=False,
    cache=False,
    pad=0.0,
    rect=False,
    rank=-1,
    workers=8,
    image_weights=False,
    quad=False,
    prefix="",
    shuffle=False,
    seed=0,
):
    # 创建数据集实例
    dataset = LoadImagesAndLabels(
        path,
        imgsz,
        batch_size,
        # 其他参数...
    )
  
    # 配置批次大小和采样器
    batch_size = min(batch_size, len(dataset))
    sampler = None if rank == -1 else distributed.DistributedSampler(...)
  
    # 选择数据加载器类型
    loader = InfiniteDataLoader if image_weights else DataLoader
  
    # 创建并返回数据加载器
    return loader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle and sampler is None,
        # 其他参数...
    ), dataset
  1. 创建数据集:实例化LoadImagesAndLabels
  2. 决定批次大小:确保批次大小不超过数据集大小
  3. 设置采样器:根据是否分布式训练选择适当的采样器
  4. 选择数据加载器类型:根据是否使用图像权重选择DataLoaderInfiniteDataLoader
  5. 配置数据加载参数
    • 批次大小
    • 是否随机打乱
    • 工作线程数
    • 采样器
    • 是否丢弃最后一个不完整批次
    • 内存固定
    • 数据整理函数(collate_fn)
    • 工作线程初始化函数
    • 随机数生成器

3.4 数据增强的实际发生位置

数据增强主要在LoadImagesAndLabels类的__getitem__方法中实际执行,下面是简化的方法流程:

def __getitem__(self, index):
    # 1. 获取索引
    index = self.indices[index]
  
    # 2. 决定是否使用Mosaic增强
    if self.mosaic and random.random() < self.hyp["mosaic"]:
        # 加载Mosaic增强图像
        img, labels = self.load_mosaic(index)
      
        # 可能应用MixUp增强
        if random.random() < self.hyp["mixup"]:
            img, labels = mixup(...)
    else:
        # 常规加载图像
        img, (h0, w0), (h, w) = self.load_image(index)
        # 应用Letterbox
        img, ratio, pad = letterbox(...)
        # 处理标签
        labels = self.labels[index].copy()
      
    # 3. 应用更多增强操作
    if self.augment:
        # 随机透视变换
        img, labels = random_perspective(...)
      
        # Albumentations库增强
        img, labels = self.albumentations(img, labels)
      
        # HSV颜色空间增强
        augment_hsv(...)
      
        # 随机翻转
        if random.random() < self.hyp["flipud"]:
            img = np.flipud(img)
          
        if random.random() < self.hyp["fliplr"]:
            img = np.fliplr(img)
  
    # 4. 最终处理
    # 标签格式转换
    labels_out = torch.zeros((len(labels), 6))
    # 图像格式转换
    img = img.transpose((2, 0, 1))[::-1]
  
    return torch.from_numpy(img), labels_out, self.im_files[index], shapes

当训练循环请求一个批次的数据时,这个方法会:

  1. 选择是否应用Mosaic增强(根据概率)
  2. 应用随机透视变换
  3. 应用Albumentations库增强
  4. 应用HSV颜色空间增强
  5. 应用随机翻转(上下和左右)
  6. 可选应用Cutout增强(当前似乎被注释掉了)

3.5 总结

整个YOLOv5数据加载过程是:

  1. train.py中调用create_dataloader函数,带有augment=True参数
  2. create_dataloader函数创建一个LoadImagesAndLabels类实例,并将augment=True传递给它
  3. create_dataloader将数据集包装在PyTorch的DataLoader
  4. 训练循环使用这个DataLoader获取数据批次
  5. 每次获取批次时,LoadImagesAndLabels类的__getitem__方法被调用,应用各种数据增强

这种设计使YOLOv5能够灵活地处理各种数据格式,并且应用复杂的数据增强策略,同时保持代码的模块化和可扩展性。


4. YOLOv5中数据增强的完整流程解析

4.1 完整调用流程

4.1.1 训练循环中的批次获取

train.py中的训练循环里,我们可以看到这样的代码:

pbar = enumerate(train_loader)
...
for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
    callbacks.run("on_train_batch_start")
    ni = i + nb * epoch  # number integrated batches (since train start)
    imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0
    ...

当循环遍历train_loader时,实际上是在调用PyTorch的数据加载流程。以下是完整的调用链:

  1. for i, (imgs, targets, paths, _) in pbar 这行代码触发了PyTorch的数据加载流程
  2. PyTorch的DataLoader会创建工作线程从数据集中获取样本
  3. DataLoader调用LoadImagesAndLabels.__getitem__(index)来获取单个样本
  4. DataLoader使用collate_fn函数将多个样本组合成一个批次
  5. 返回组合好的批次数据(imgs, targets, paths, _)给训练循环

4.1.2 __getitem__方法的调用时机

当训练流程需要加载一个批次的数据时:

  • 如果是第一次迭代,DataLoader会创建一个迭代器
  • 迭代器会根据批次大小和采样器确定要加载的样本索引
  • 对于每个索引,DataLoader会调用dataset[index],也就是LoadImagesAndLabels.__getitem__(index)
  • 这个方法会返回处理好的单个样本(图像和标签)
  • DataLoader将多个样本组合成一个批次返回给训练循环

4.1.3 __getitem__方法的数据增强实现

__getitem__方法中数据增强的步骤:

def __getitem__(self, index):
    """获取数据集中指定索引的样本,考虑线性、随机或加权采样。"""
    index = self.indices[index]  # 线性、随机或图像权重
  
    hyp = self.hyp
    if mosaic := self.mosaic and random.random() < hyp["mosaic"]:
        # 加载Mosaic增强
        img, labels = self.load_mosaic(index)
        shapes = None
      
        # MixUp增强
        if random.random() < hyp["mixup"]:
            img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices)))
    else:
        # 常规加载图像
        img, (h0, w0), (h, w) = self.load_image(index)
      
        # Letterbox
        shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size
        img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
        shapes = (h0, w0), ((h / h0, w / w0), pad)
      
        # 处理标签
        labels = self.labels[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
      
        # 随机透视变换
        if self.augment:
            img, labels = random_perspective(
                img,
                labels,
                degrees=hyp["degrees"],
                translate=hyp["translate"],
                scale=hyp["scale"],
                shear=hyp["shear"],
                perspective=hyp["perspective"],
            )
  
    nl = len(labels)  # 标签数量
    if nl:
        labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3)
  
    # 更多数据增强操作
    if self.augment:
        # Albumentations库增强
        img, labels = self.albumentations(img, labels)
        nl = len(labels)  # 更新标签数量
      
        # HSV颜色空间增强
        augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])
      
        # 上下翻转
        if random.random() < hyp["flipud"]:
            img = np.flipud(img)
            if nl:
                labels[:, 2] = 1 - labels[:, 2]
      
        # 左右翻转
        if random.random() < hyp["fliplr"]:
            img = np.fliplr(img)
            if nl:
                labels[:, 1] = 1 - labels[:, 1]
      
        # Cutout(被注释掉了)
        # labels = cutout(img, labels, p=0.5)
  
    # 格式转换
    labels_out = torch.zeros((nl, 6))
    if nl:
        labels_out[:, 1:] = torch.from_numpy(labels)
  
    # 图像格式转换:HWC到CHW,BGR到RGB
    img = img.transpose((2, 0, 1))[::-1]
    img = np.ascontiguousarray(img)
  
    return torch.from_numpy(img), labels_out, self.im_files[index], shapes

4.1.4 关键增强操作详解

a. Mosaic增强
  • 随机选择4张图像,将它们组合成一个大图像
  • 随机确定mosaic中心点位置
  • 调整四张图像的大小和位置
  • 调整对应的标签坐标
if mosaic := self.mosaic and random.random() < hyp["mosaic"]:
    img, labels = self.load_mosaic(index)
b. MixUp增强
  • 在Mosaic之后可能应用
  • 将两个Mosaic增强的样本按一定比例混合
  • 合并两个样本的标签
if random.random() < hyp["mixup"]:
    img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices)))
c. 随机透视变换
  • 应用旋转、平移、缩放、剪切等几何变换
  • 同时调整标签坐标以匹配变换后的图像
if self.augment:
    img, labels = random_perspective(
        img,
        labels,
        degrees=hyp["degrees"],
        translate=hyp["translate"],
        scale=hyp["scale"],
        shear=hyp["shear"],
        perspective=hyp["perspective"],
    )
d. Albumentations库增强
  • 使用第三方库Albumentations提供的额外增强
  • 这是一个条件性操作,只有在初始化时创建了albumentations对象才会应用
if self.augment:
    img, labels = self.albumentations(img, labels)
e. HSV颜色空间增强
  • 在HSV颜色空间中调整色调、饱和度和亮度
  • 随机变化的强度由超参数控制
if self.augment:
    augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])
f. 随机翻转
  • 随机进行上下和左右翻转
  • 同时调整标签坐标
if self.augment:
    if random.random() < hyp["flipud"]:
        img = np.flipud(img)
        if nl:
            labels[:, 2] = 1 - labels[:, 2]
  
    if random.random() < hyp["fliplr"]:
        img = np.fliplr(img)
        if nl:
            labels[:, 1] = 1 - labels[:, 1]

4.1.5 增强概率控制

每种增强操作的应用概率是通过超参数hyp控制的:

  • hyp["mosaic"]: Mosaic增强概率
  • hyp["mixup"]: MixUp增强概率
  • hyp["flipud"]: 上下翻转概率
  • hyp["fliplr"]: 左右翻转概率

其他增强操作的强度也由超参数控制:

  • hyp["degrees"]: 旋转角度范围
  • hyp["translate"]: 平移范围
  • hyp["scale"]: 缩放范围
  • hyp["shear"]: 剪切范围
  • hyp["perspective"]: 透视变换强度
  • hyp["hsv_h"], hyp["hsv_s"], hyp["hsv_v"]: HSV颜色空间调整强度

4.2 总结:数据增强的完整流程

  1. 触发时机:当训练循环通过 for i, (imgs, targets, paths, _) in pbar 迭代DataLoader
  2. 调用过程:PyTorch DataLoader → 工作线程 → LoadImagesAndLabels.__getitem__(index)collate_fn → 批次数据
  3. 预处理:加载图像,应用letterbox调整大小,调整标签坐标
  4. 主要增强
    • Mosaic增强(合并4张图像)
    • MixUp增强(混合2个样本)
    • 随机透视变换(旋转、平移、缩放、剪切)
    • Albumentations库增强
    • HSV颜色空间增强
    • 随机翻转(上下和左右)
  5. 格式转换:转换图像格式,准备好PyTorch需要的张量格式
  6. 返回结果:处理好的图像、标签、文件路径和形状信息
sequenceDiagram
    participant Train as 训练循环
    participant Loader as DataLoader
    participant Dataset as LoadImagesAndLabels
    participant Aug as 数据增强
  
    Train->>Loader: 迭代请求批次数据
    Loader->>Dataset: __getitem__(index)
    Dataset->>Dataset: 加载图像
    alt Mosaic增强
        Dataset->>Aug: load_mosaic()
        opt MixUp增强
            Dataset->>Aug: mixup()
        end
    else 常规加载
        Dataset->>Dataset: 加载图像+letterbox
        opt 随机透视变换
            Dataset->>Aug: random_perspective()
        end
    end
    opt 其他增强
        Dataset->>Aug: albumentations增强
        Dataset->>Aug: HSV颜色空间增强
        Dataset->>Aug: 随机翻转
    end
    Dataset->>Loader: 返回处理后的样本
    Loader->>Train: 返回批次数据

参考资料:

文章对话

由AI生成的"小T"和"好奇宝宝"之间的对话,帮助理解文章内容