前言
我们知道,现在会很多框架中都有集成数据集的库,我们可以通过几行简单的代码进行下载,如我们前几篇博客中的 MNIST 手写数字集和 Fashion MNIST数据集均可以通过PyTorch中的torchvision库下载。但库中收集到的数据集始终有限,我们则需要利用爬虫等技术进行收集新数据集,如果要用于模型训练,则需要我们自定义数据集。
1 猫狗数据集
猫狗数据集可以从kaggle官网下载,该数据集包含test1文件夹和train文件夹,train文件夹中包含12500张猫的图片和12500张狗的图片,图片的文件名中带序号:
2 自定义数据加载
实际应用中,样本以及样本标签的存储方式可能各不相同,如有些场合所有的图片存储在同一目录下,类别名可从图片名字中推导出。有些数据集样本的标签信息保存为 JSON 格式的文本文件中,需要按照 JSON 格式查询每个样本的标签。不管数据集是以什么方式存储的,我们总是能够用过逻辑规则获取所有样本的路径和标签信息。
2.1 创建编码表
样本的类别一般以字符串类型的类别名标记,但是对于神经网络来说,首先需要将类别名进行数字编码,然后在合适的时候再转换成 One-hot 编码或其他编码格式。考虑𝑛个类 别的数据集,我们将每个类别随机编码为𝑙 ∈ [0, 𝑛 − 1]的数字,类别名与数字的映射关系称编码表,一旦创建后,一般不能变动。
针对猫狗数据集,我用0和1来标注其种类。首先遍历train文件夹得所有子目录(图片),再以类别名作为字典的键,后以编码表的现有键值对数量作为类别的标签映射数字,并保存至字典对象。
1 | def load_data(root) : |
2.2 创建样本和标签表格
编码表确定后,我们需要根据实际数据的存储方式获得每个样本的存储路径以及它的标签数字,最终保存为csv文件,csv文件格式 是一种以逗号为分隔符号的纯文本格式。通过将所有样本信息存储在一个 csv 文件中有诸多好处,比如可以直接进行数据集的划分,可以随机采样 Batch 等。csv 文件中可以保存数据集所有样本的信息,也可以根据训练集、验证集和测试集分别创建 3 个 csv 文件。
1 | def create_csv(root, filename, name2label): |
创建完 csv 文件后,下一次只需要从 csv 文件中读取样本路径和标签信息即可,而不需要每次都生成 csv 文件,提高计算效率。
1 | def load_csv(filename, name2label): |
2.3 数据集划分
数据集的划分需要根据实际情况来灵活调整划分比率。当数据集样本数较多时,可以选择 80%-10%-10%的比例分配给训练集、验证集和测试集。这里猫狗的数据集因为训练集和测试集各有12500张且已分开,因此只需要从训练集中划分10% 作为验证集。对于小型的数据集,尽管样本数量较小,但还是需要适当增加验证集和测试集的比例,以保证获得准确的测试结果。
首先调用 load_csv 函数加载 images 和 labels 列表,根据当前模式参数 mode 加载对应部分的图片和标签。具体地,如果模式参数为 train,则分别取 images 和 labels 的前 90%数据作为训练集;如果模式参数为 val,则分别取 images 和 labels 的 90%到 100%区域数据作为验证集。(因为猫狗数据集已经有测试集,所以不再考虑)
1 | def load_cat_dog(filename, mode = 'train') : |
3 PyTorch实现自定义数据集
首先我们先回顾一下在线下载MNIST手写数字数据集及训练的大致过程代码过程。
1 | # 数据下载 |
从上面可知,PyTorch中数据传递机制是这样的:
- 创建Dataset
- Dataset传递给DataLoader
- DataLoader迭代产生训练数据提供给模型
数据集的自定义过程主要体现在创建Dataset过程,编写过程要继承torch.utils.data.Dataset ,该类是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类。必须要重写getitem(self, index)、 len(self) 两个内建方法,用来表示从索引到样本的映射(Map),重写之后我们可以直接使用下标来获取想要的数据。我们在使用 torch.utils.data.DataLoader 构建数据集时传入的参数是 images 和 labels 组成的 tuple,因此在对数据对象迭代时,返回的是(𝑿𝑖, 𝒀𝑖)的 tuple 对象,其中𝑿𝑖是第𝑖 个 Batch 的图片张量数据,𝒀𝑖是第𝑖个 Batch 的图片标签数据。代码如下,还是以猫狗数据集为例,且在前面我们已经将训练集和测试集数据保存到 train.csv 和 test.csv 文件(注意测试集和训练集的区别,测试集是没有标签的)中,我这里代码只显示训练集的获取,测试集类似。
1 | import torch |
运行结果: