第 5 章:数据加载与预处理¶
核心比喻:DataLoader = 传送带 —— 就像工厂的传送带自动把原材料送到工位,DataLoader 自动把数据分批送到模型。
5.1 PyTorch 数据加载体系¶
PyTorch 提供了两个核心类来管理数据:
| 类 | 作用 | 比喻 |
|---|---|---|
torch.utils.data.Dataset |
存储数据及其标签 | 仓库 |
torch.utils.data.DataLoader |
批量加载、打乱、并行处理 | 传送带 |
5.2 自定义 Dataset¶
从内存数据创建¶
class SimpleDataset(Dataset):
"""最简单的自定义数据集"""
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
# 创建模拟数据
X = torch.randn(1000, 10)
y = torch.randint(0, 3, (1000,))
dataset = SimpleDataset(X, y)
print(f"数据集大小: {len(dataset)}")
print(f"第 0 个样本: data shape={dataset[0][0].shape}, label={dataset[0][1]}")
运行结果:
从文件创建¶
import os
from PIL import Image
class ImageFolderDataset(Dataset):
"""从文件夹加载图像的数据集"""
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.image_paths = []
self.labels = []
# 遍历文件夹,每个子文件夹作为一个类别
for label, class_name in enumerate(sorted(os.listdir(root_dir))):
class_dir = os.path.join(root_dir, class_name)
if os.path.isdir(class_dir):
for img_name in os.listdir(class_dir):
if img_name.endswith(('.jpg', '.png', '.jpeg')):
self.image_paths.append(os.path.join(class_dir, img_name))
self.labels.append(label)
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
label = self.labels[idx]
if self.transform:
image = self.transform(image)
return image, label
print("ImageFolderDataset 类已定义(需要实际图像文件夹才能使用)")
5.3 DataLoader 详解¶
# 创建数据集
X = torch.randn(1000, 10)
y = torch.randint(0, 3, (1000,))
dataset = SimpleDataset(X, y)
# 创建 DataLoader
dataloader = DataLoader(
dataset,
batch_size=32, # 每批 32 个样本
shuffle=True, # 每个 epoch 打乱数据
num_workers=0, # 数据加载的子进程数(Windows 建议设为 0)
drop_last=False, # 是否丢弃最后不足一个 batch 的数据
pin_memory=False, # 是否将数据锁页(GPU 训练时设为 True 可加速)
)
# 遍历 DataLoader
for batch_idx, (data, labels) in enumerate(dataloader):
print(f"Batch {batch_idx}: data shape={data.shape}, labels shape={labels.shape}")
if batch_idx >= 2:
break
print(f"\n总批次数: {len(dataloader)}")
print(f"每批大小: {dataloader.batch_size}")
运行结果:
Batch 0: data shape=torch.Size([32, 10]), labels shape=torch.Size([32])
Batch 1: data shape=torch.Size([32, 10]), labels shape=torch.Size([32])
Batch 2: data shape=torch.Size([32, 10]), labels shape=torch.Size([32])
总批次数: 32
每批大小: 32
DataLoader 参数说明
| 参数 | 说明 | 建议值 |
|---|---|---|
batch_size |
每批样本数 | 32/64/128(根据显存调整) |
shuffle |
是否打乱 | 训练=True,验证=False |
num_workers |
并行加载进程数 | Linux: 4-8, Windows: 0 |
drop_last |
丢弃最后不完整 batch | 训练时若 BatchNorm 报错则设为 True |
pin_memory |
锁页内存 | GPU 训练时设为 True |
5.4 数据预处理(Transforms)¶
torchvision.transforms 提供了丰富的图像预处理操作:
from torchvision import transforms
# 定义预处理流水线
transform_pipeline = transforms.Compose([
transforms.Resize((256, 256)), # 调整大小
transforms.RandomCrop(224), # 随机裁剪
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ColorJitter( # 颜色抖动
brightness=0.2,
contrast=0.2,
saturation=0.2,
hue=0.1
),
transforms.ToTensor(), # 转为张量 [0, 1]
transforms.Normalize( # 标准化
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
print("训练预处理流水线:")
for i, t in enumerate(transform_pipeline.transforms):
print(f" {i+1}. {t.__class__.__name__}")
# 验证集预处理(不需要数据增强)
val_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
print("\n验证预处理流水线:")
for i, t in enumerate(val_transform.transforms):
print(f" {i+1}. {t.__class__.__name__}")
运行结果:
训练预处理流水线:
1. Resize
2. RandomCrop
3. RandomHorizontalFlip
4. ColorJitter
5. ToTensor
6. Normalize
验证预处理流水线:
1. Resize
2. CenterCrop
3. ToTensor
4. Normalize
数据增强的重要性
数据增强是防止过拟合的 最有效手段之一: - 随机翻转/旋转:让模型学习方向不变性 - 颜色抖动:让模型适应光照变化 - 随机裁剪:让模型关注局部特征 - Cutout/RandomErasing:随机遮挡,增强鲁棒性
5.5 使用内置数据集¶
PyTorch 的 torchvision.datasets 提供了许多常用数据集:
from torchvision import datasets
# CIFAR-10
cifar10_train = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transforms.ToTensor()
)
# Fashion-MNIST
fashion_mnist = datasets.FashionMNIST(
root='./data',
train=True,
download=True,
transform=transforms.ToTensor()
)
print(f"CIFAR-10 训练集: {len(cifar10_train)} 张图像, 类别数: {len(cifar10_train.classes)}")
print(f"CIFAR-10 类别: {cifar10_train.classes}")
print(f"\nFashion-MNIST 训练集: {len(fashion_mnist)} 张图像, 类别数: {len(fashion_mnist.classes)}")
print(f"Fashion-MNIST 类别: {fashion_mnist.classes}")
运行结果:
CIFAR-10 训练集: 50000 张图像, 类别数: 10
CIFAR-10 类别: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Fashion-MNIST 训练集: 60000 张图像, 类别数: 10
Fashion-MNIST 类别: ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
5.6 数据划分:训练集 / 验证集 / 测试集¶
from torch.utils.data import random_split
# 创建完整数据集
full_dataset = SimpleDataset(
torch.randn(10000, 10),
torch.randint(0, 3, (10000,))
)
# 按比例划分:80% 训练,10% 验证,10% 测试
train_size = int(0.8 * len(full_dataset))
val_size = int(0.1 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
full_dataset, [train_size, val_size, test_size]
)
print(f"训练集: {len(train_dataset)} 样本")
print(f"验证集: {len(val_dataset)} 样本")
print(f"测试集: {len(test_dataset)} 样本")
# 为每个数据集创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
运行结果:
5.7 文本数据加载¶
from torch.utils.data import Dataset
class TextDataset(Dataset):
"""文本分类数据集"""
def __init__(self, texts, labels, tokenizer, max_length=128):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
label = self.labels[idx]
# 分词
encoding = self.tokenizer(
text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].squeeze(0),
'attention_mask': encoding['attention_mask'].squeeze(0),
'label': torch.tensor(label)
}
# 模拟数据
texts = [
"这部电影太棒了,强烈推荐",
"浪费时间,剧情无聊透顶",
"还不错,值得一看",
"演员演技在线,导演功力深厚",
]
labels = [1, 0, 1, 1]
print(f"文本数据集: {len(texts)} 条文本")
print(f"标签分布: 正面={labels.count(1)}, 负面={labels.count(0)}")
运行结果:
要点总结¶
-
Dataset定义数据存储和访问方式,必须实现__len__和__getitem__ -
DataLoader负责批量加载、打乱、并行处理 -
transforms.Compose组合多个预处理操作 - 训练集需要数据增强,验证/测试集只需要标准化
-
random_split方便地划分数据集 -
torchvision.datasets提供常用图像数据集 - 自定义 Dataset 可处理任意类型的数据(图像、文本、音频等)
课后练习¶
-
自定义 Dataset:创建一个 Dataset 类,从 CSV 文件加载数据(包含特征列和标签列)。
-
数据增强实验:对同一张图像应用不同的 transforms 组合,观察增强效果。
-
DataLoader 参数调优:尝试不同的
batch_size和num_workers,比较数据加载速度。 -
不平衡数据处理:实现一个加权采样器(
WeightedRandomSampler),处理类别不平衡的数据集。