YOLOv5的dataloaders.py代码精读
目录
在 yolov 5 目标检测任务中,我跑 train. Py 代码 ,在train_loader 中,那我就调用了create_dataloader 的函数,在该函数内部创建了LoadImagesAndLabels 类的实例作为 dataset,在这个 create_dataloader 函数最后返回一个 DataLoader和数据集:
-dataloaders. Py
PyTorch 数据集中的 __getitem__
方法工作原理
__getitem__
是 Python 中的一个特殊方法(魔术方法),在 YOLOv 5 的 LoadImagesAndLabels
类中用于访问数据集中的单个样本。当您使用数据加载器或直接通过索引访问数据集时,这个方法会被调用。
访问流程
当执行以下操作时,__getitem__
方法被调用:
- 直接从数据集访问:
image, label = dataset[5]
- 通过 DataLoader 迭代:
for images, labels in dataloader: ...
在 DataLoader 中,__getitem__
会被多次并行调用(由 num_workers
参数决定),然后结果通过 collate_fn
方法合并为批次。
YOLOv 5 中 __getitem__
的工作流程
def __getitem__(self, index):
# 1. 将传入的索引转换为实际使用的索引(处理线性、打乱或加权采样)
index = self.indices[index]
# 2. 检查是否应用mosaic增强(基于配置和随机概率)
if mosaic := self.mosaic and random.random() < hyp["mosaic"]:
# 加载mosaic增强的图像和标签
img, labels = self.load_mosaic(index) # 这里是您想修改为load_mosaic9的地方
shapes = None
# 检查是否进一步应用mixup增强
if random.random() < hyp["mixup"]:
img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices)))
else:
# 3. 不使用mosaic时的常规图像加载和处理流程
img, (h0, w0), (h, w) = self.load_image(index)
# Letterbox处理
shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size
img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
shapes = (h0, w0), ((h / h0, w / w0), pad)
# 处理标签
labels = self.labels[index].copy()
if labels.size:
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
# 应用随机透视变换等增强
if self.augment:
img, labels = random_perspective(...)
# 4. 标签格式转换
nl = len(labels)
if nl:
labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3)
# 5. 应用更多的增强技术(如果启用)
if self.augment:
# Albumentations增强
img, labels = self.albumentations(img, labels)
# HSV颜色空间增强
augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])
# 上下翻转
if random.random() < hyp["flipud"]:
img = np.flipud(img)
if nl:
labels[:, 2] = 1 - labels[:, 2]
# 左右翻转
if random.random() < hyp["fliplr"]:
img = np.fliplr(img)
if nl:
labels[:, 1] = 1 - labels[:, 1]
# 6. 准备输出格式
labels_out = torch.zeros((nl, 6))
if nl:
labels_out[:, 1:] = torch.from_numpy(labels)
# 7. 图像格式转换:HWC->CHW, BGR->RGB
img = img.transpose((2, 0, 1))[::-1]
img = np.ascontiguousarray(img)
# 8. 返回最终处理好的数据
return torch.from_numpy(img), labels_out, self.im_files[index], shapes
DataLoader 如何使用 __getitem__
- PyTorch 的 DataLoader 创建多个工作进程(worker)
- 每个 worker 负责获取批次中的部分样本,调用数据集的
__getitem__
方法 - 所有样本收集完后,通过
collate_fn
函数合并为一个批次 - 最终批次传递给模型进行训练
修改 mosaic 增强
要将 4 图像 mosaic 修改为 9 图像 mosaic,您只需修改第 478 行的 self.load_mosaic(index)
为 self.load_mosaic9(index)
。这样当启用 mosaic 增强时,系统会使用 9 张图片而不是 4 张图片来创建马赛克增强效果。
这种修改不会影响 __getitem__
的整体工作流程,只是改变了具体使用的增强方法。
Train. Py
train_loader, dataset = create_dataloader(
train_path,
imgsz,
batch_size // WORLD_SIZE,
gs,
single_cls,
hyp=hyp,
augment=True, # 数据增强在训练中默认开启
cache=None if opt.cache == "val" else opt.cache,
rect=opt.rect,
rank=LOCAL_RANK,
workers=workers,
image_weights=opt.image_weights,
quad=opt.quad,
prefix=colorstr("train: "),
shuffle=True,
seed=opt.seed,
)
Dataloader 迭代器
- 创建数据加载器:
# Trainloader
train_loader, dataset = create_dataloader(
train_path,
imgsz,
batch_size // WORLD_SIZE,
gs,
single_cls,
hyp=hyp,
augment=True, # 数据增强在训练中默认开启
cache=None if opt.cache == "val" else opt.cache,
rect=opt.rect,
rank=LOCAL_RANK,
workers=workers,
image_weights=opt.image_weights,
quad=opt.quad,
prefix=colorstr("train: "),
shuffle=True,
seed=opt.seed,
)
# 训练循环
for epoch in range(start_epoch, epochs):
model.train()
# ...
pbar = enumerate(train_loader)
# ...
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
# 处理批次数据
# ...
- 训练循环中的迭代
# Trainloader
train_loader, dataset = create_dataloader(
train_path,
imgsz,
batch_size // WORLD_SIZE,
gs,
single_cls,
hyp=hyp,
augment=True, # 数据增强在训练中默认开启
cache=None if opt.cache == "val" else opt.cache,
rect=opt.rect,
rank=LOCAL_RANK,
workers=workers,
image_weights=opt.image_weights,
quad=opt.quad,
prefix=colorstr("train: "),
shuffle=True,
seed=opt.seed,
)
# 训练循环
for epoch in range(start_epoch, epochs):
model.train()
# ...
pbar = enumerate(train_loader)
# ...
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
# 处理批次数据
# ...
工作原理
DataLoader作为一个迭代器,内部维护了对数据集的访问机制
当执行for循环迭代它时,它会:
根据batch_size确定需要获取哪些索引
将这些工作分配给多个工作进程(由workers参数控制)
每个工作进程调用数据集的__getitem__方法获取对应索引的样本
使用collate_fn将这些样本组合成一个批次
返回组合好的批次
- 对于YOLOv5,使用了特殊的InfiniteDataLoader类,它是对标准DataLoader的扩展,目的是重用工作进程以提高效率
WORLD_SIZE 参数
#分布式
WORLD_SIZE
是与分布式训练(Distributed Data Parallel, DDP)相关的参数。
在分布式训练环境中:
WORLD_SIZE
表示参与训练的总进程数,通常等于使用的 GPU 数量- 当在多个 GPU 上进行并行训练时,每个 GPU 上运行一个进程
- 总批次大小(batch_size)需要在所有 GPU 之间平均分配
例如,如果设置批次大小为 32,并且使用 4 个 GPU(WORLD_SIZE=4
)进行分布式训练,那么每个 GPU 实际处理的批次大小就是 32/4=8。
表达式 batch_size // WORLD_SIZE
确保总批次大小被平均分配到每个进程(GPU)上,这样每个进程将处理更小的批次,但所有进程一起处理的数据量等于指定的总批次大小。
在单 GPU 或 CPU 训练时,WORLD_SIZE
默认为 1,这种情况下 batch_size // WORLD_SIZE
就等于 batch_size
。
WORLD_SIZE 在 YOLOv 5 中的作用
WORLD_SIZE
是与分布式训练(Distributed Data Parallel,简称 DDP)相关的参数,表示参与训练的总 GPU 进程数。
WORLD_SIZE 详解
定义:
WORLD_SIZE
表示参与训练的 GPU 总数(或进程总数)- 在代码中通常通过环境变量设置:
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
- 默认值为 1,表示单 GPU 或 CPU 训练
作用:
- 用于将总批次大小平均分配到每个 GPU
- 确保全局批次大小在分布式训练中正确维持
在数据加载中的应用:
batch_size // WORLD_SIZE
计算每个 GPU 负责处理的样本数量- 例如,总批次大小为 32,使用 4 个 GPU(WORLD_SIZE=4),则每个 GPU 处理 8 个样本
实际应用场景
- 单 GPU 训练:WORLD_SIZE=1,每个 GPU 处理完整批次
- 多 GPU 训练:例如 4 个 GPU 时,WORLD_SIZE=4,批次均分
启动方式不同的影响
普通训练:
python train.py --batch-size 16
- WORLD_SIZE=1,每个批次 16 个样本
分布式训练:
python -m torch.distributed.run --nproc_per_node 4 train.py --batch-size 16
- WORLD_SIZE=4,每个 GPU 处理 4 个样本(16/4=4)
- 总体仍处理 16 个样本,但跨 4 个 GPU 并行
这种设计使得 YOLOv 5 可以无缝地在单 GPU 和多 GPU 环境中运行,只需调整启动命令,而不需要修改代码逻辑。
Dataloaders. Py 文件中核心参数
1. 创建数据加载器的核心参数 (create_dataloader 函数)
- path: 数据集路径,决定加载哪些数据
- imgsz: 图像大小,影响模型输入分辨率和内存消耗
- batch_size: 批次大小,影响训练速度和内存使用
- stride: 模型步长,用于计算图像尺寸和网格对齐
- augment: 是否启用数据增强,训练时通常为 True
- rect: 是否使用矩形训练,可以减少填充提高效率
- cache: 缓存模式 (“ram”/“disk”/None),影响加载速度
- workers: 数据加载线程数,影响 CPU 使用和加载速度
- image_weights: 是否使用图像权重采样
2. 数据集类的关键参数 (LoadImagesAndLabels 类)
- mosaic: 马赛克增强,我们讨论的修改点
- mosaic_border: 马赛克边界值,控制马赛克图像布局
- hyp: 超参数字典,包含所有数据增强参数
- img_size: 目标图像大小
- rect: 矩形训练模式,按宽高比对图像进行分组
- albumentations: 外部增强库的集成
3. 数据增强相关参数 (hyp 字典中)
- mosaic: 马赛克增强概率
- mixup: 混合增强概率
- copy_paste: 分割复制粘贴概率
- hsv_h/hsv_s/hsv_v: HSV 颜色空间增强参数
- flipud/fliplr: 上下/左右翻转概率
- degrees/translate/scale/shear: 几何变换参数
- perspective: 透视变换参数
4. 性能优化参数
- cache_images: 是否缓存图像到 RAM 或磁盘
- pin_memory: 是否使用内存锁定加速 GPU 传输
- rect: 矩形训练参数,减少填充
- NUM_THREADS: 多进程操作的线程数
5. 分布式训练参数
- WORLD_SIZE: 参与训练的 GPU 总数
- RANK: 当前进程的全局排名
- LOCAL_RANK: 当前进程的本地排名
关键数据加载方法
- getitem: 加载和处理单个样本的核心方法
- load_mosaic/load_mosaic 9: 马赛克增强方法
- load_image: 基本图像加载方法
- collate_fn: 将多个样本组合成批次的方法
文章对话
由AI生成的"小T"和"好奇宝宝"之间的对话,帮助理解文章内容