互联网技术 / 互联网资讯 · 2024年3月9日

PyTorch 的 Torchvision 与 Torchtext 数据集概览

本文梳理了 PyTorch 中常用的数据加载与处理资源,聚焦 Torchvision 与 Torchtext 提供的内置数据集及其应用方式,帮助读者快速了解可直接使用的数据源与加载流程。

Torchvision 提供了多种图像数据集,用于训练与评估视觉模型;Torchtext 则提供文本与情感分析相关的数据集,方便进行自然语言处理实验。

常用的 Torchvision 数据集

MNIST:一个标准化、中心裁剪的手写数字图像数据集,包含超过 60,000 张训练样本和 10,000 张测试样本。常用于入门与实验性任务。加载方式通常通过 Torchvision.datasets.MNIST。

Fashion MNIST:与 MNIST 类似的结构,但数据集中包含服装类别(如 T 恤、裤子、背包等),训练与测试样本各 60,000 与 10,000。加载方式通常通过 Torchvision.datasets.FashionMNIST。

CIFAR:包含 CIFAR-10 与 CIFAR-100 两个版本,前者有 10 类、后者有 100 类,图像涵盖多种日常对象如车辆、动物等。

COCO:覆盖超过 10 万个日常对象,常用于目标检测与图像描述任务。加载方式通常通过 Torchvision.datasets.CocoCaptions。

EMNIST:MNIST 的扩展版本,包含数字与字母,适用于文本识别相关的图像任务。加载入口通常为 Torchvision.datasets.EMNIST。

ImageNet:旗舰级数据集,涵盖约 120 万张图像,分为 10,000 余个类别,通常用于训练高端模型,硬件资源要求较高。加载入口通常为 Torchvision.datasets.ImageNet。

常用的 Torchtext 数据集

IMDB:情感分析数据集,包含 25,000 条训练样本与 25,000 条测试样本,用于评估文本分类能力。加载入口为 Torchtext.datasets.IMDB。

WikiText-2:大型语言建模数据集,包含超过 1 亿个标记,保留了原文的标点与大小写,广泛用于长期依赖建模等任务。加载入口为 Torchtext.datasets.WikiText2。

从 MNIST 演示数据加载流程

MNIST 是最受欢迎的数据集之一,我们以一个示例演示如何从 Torchvision 下载并加载数据到变量 data_train 中。

示例说明:下载并加载 MNIST 数据到 data_train;使用 Matplotlib 显示一张样本图像及其标签;接着展示如何通过 DataLoader 将数据分批处理。

示例要点包括:
– 使用 MNIST(…) 下载并加载数据
– 将数据转换为张量格式
– 通过 DataLoader 实现批次加载与混洗

从 Torchvision 导入 MNIST 数据集并结合 torch.utils.data.Dataloader 及 transforms 的常用做法如下所示:
[[[IMG_1]]]
[[[IMG_2]]]

imageFolder 是一个通用数据加载器类,用于加载自定义图像数据集。通过将图像按根目录结构组织,可以训练一个简单的图像分类模型,例如识别橙色与苹果图像。典型结构如下:

Root ── orange/ ── orange_image1.png ── apple/ ── apple_image1.png

imageFolder(Root, transforms) 表示对图像应用的变换,transform 可以包括裁剪、缩放、归一化等操作,以便输入网络。

图像变换与数据预处理

常用的图像变换组合包括:将图像缩放为统一尺寸、中心裁剪、转为张量、以及标准化。通过组合 transforms,可以在一个步骤中完成所有处理,从而确保数据在训练阶段具有稳定的一致性。

示例:对 CIFAR-10 的数据应用一组变换,包含缩放、中心裁剪与张量化,并进行标准化处理。随后将结果传入数据加载器以实现分批训练。

以下示例展示了如何导入必要模块并应用变换:
– 导入 Torch、Torchvision 与 transforms
– 使用 Compose 组合多个 transforms
– 应用 Resize、CenterCrop、ToTensor、Normalize 等变换

如果需要在训练中开启 GPU 加速,可以检测 CUDA 是否可用,并据此创建 DataLoader。示例逻辑包括:

device = “cuda” if torch.cuda.is_available() else “cpu”
num_workers = 1
pin_memory = True if device == “cuda” else False
train_loader = torch.utils.data.DataLoader(MNIST(“/files/”, train=True, download=True), batch_size=64, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, transform=transforms)
test_loader = torch.utils.data.DataLoader(MNIST(“/files/”, train=False, download=True), batch_size=64, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, transform=transforms)

通过 imageFolder,可以实现自定义数据集的加载,而 transforms 组合则是实现高效预处理的关键工具。