前言
迁移学习在计算机视觉任务和自然语言处理任务中经常使用,这些模型往往需要大数据、复杂的网络结构。如果使用迁移学习,可将预训练的模型作为新模型的起点,这些预训练的模型在开发神经网络的时候已经在大数据集上训练好、模型设计也比较好,这样的模型通用性也比较好。如果要解决的问题与这些模型相关性较强,那么使用这些预训练模型,将大大地提升模型的性能和泛化能力。
1 原理
迁移学习(Transfer Learning)是机器学习的一个研究方向,主要研究如何将任务 A 上面学习到的知识迁移到任务 B 上,以提高在任务 B 上的泛化性能。例如任务 A 为猫狗分类问题,需要训练一个分类器能够较好的分辨猫和狗的样本图片,任务 B 为牛羊分类问题。可以发现,任务 A 和任务 B 存在大量的共享知识,比如这些动物都可以从毛发、体型、形 态、发色等方面进行辨别。因此在任务 A 训练获得的分类器已经掌握了这部份知识,在训练任务 B 的分类器时,可以不从零开始训练,而是在任务 A 上获得的知识的基础上面进行训练(Feature Extraction)或微调(Fine tuning),这和“站在巨人的肩膀上”思想非常类似。通过迁移任务 A 上学习的知识,在任务 B 上训练分类器可以使用更少的样本和更少的训练代价,并且获得不错的泛化能力。
在神经网络迁移学习中,主要有两个应用场景:特征提取和微调。
❑ 特征提取(Feature Extraction) :冻结除最终完全连接层之外的所有网络的权重。最后一个全连接层被替换为具有随机权重的新层,因只需要更新最后一层全连接层,使得更新参数极大地减少,节省大量的 训练时间 和 GPU 资源。
❑ 微调(Fine Tuning) :对于卷积神经网络,一般认为它能够逐层提取特征,越末层的网络的抽象特征提取能力越强,输出层一般使用与类别数相同输出节点的全连接层,作为分类网络的概率分布预测。对于相似的任务 A 和 B,如果它们的特征提取方法是相近的,则网络的前面数层可以重用。而微调技术就是使用预训练网络初始化网络,而不是随机初始化,用新数据训练部分或整个网络。小幅度更新前面的层的参数。
2 实例
进行迁移学习需要使用对应的预训练模型。PyTorch提供了很多现成的预 训练模块,我们直接拿来使用就可以。主要集成在 torchvision.models 模块中,预训练模型可以通过传递参数 pretrained = True 构造。主要的模型有 AlexNet,VGG,ResNet,SqueezeNet,DenseNet,Inception v3,GoogLeNet,ShuffleNet v2 等。
所有的预训练模型都要求输入图片以相同的方式进行标准化,即:小批l量3通道RGB格式 (3 × H × W) ,其中H和W应等于 224 。图片加载时像素值的范围应在 [0, 1] 内,然后通过指定 mean = [0.485, 0.456, 0.406] 和 std = [0.229, 0.224, 0.225] 进行标准化。
2.1 特征提取
本次案例使用的数据集是 CIFAR-10数据集 ,目标是对数据集中10类物体进行分类,只使用几层的卷积和全连接层的分类正确率只有 68% 左右,结果不算好。此案例使用迁移学习中特征提取方法来实现这个任务,预训练模型采用 retnet18 网络。
1 | #导入相关包 |
结果:
在前三个Epoch准确率就达到了 73.6% ,最终结果会达到 75% 左右,从精确率比第6章提升了近10个百分点。但是对于分类效果来说仍不是很理想。
2.2 微调
微调允许修改预先训练好的网络参数来学习目标任务,所以,虽然训练时间要比特征抽取方法长,但精度更高。微调的大致过程是在预先训练过的网络上添加新的随机初始化层,此外预先训练的网络参数也会被更新,但会使用较小的学习率以防止预先训练好的参数发生较大的改变。
在本次的微调任务中采用了数据增强的方法来使得分类效果更加。因为数据增强是提高模型的泛化能力最重要因素,数据增强技术主要有 水平或垂直翻转图像、裁剪、色彩 变换、扩展和旋转 等,通过数据增强技术不仅可以扩大训练数据集的规 模、降低模型对某些属性的依赖,从而提高模型的泛化能力,同时可以对图像进行不同方式的裁剪,使感兴趣的物体出现在不同的位置,从而减轻模型对物体出现位置的依赖性。并通过调整亮度、色彩等因素来降低模型对色彩的敏感度等。在PyTorch中图像增强的方法集成在 torchvision.transforms 模块中,主要的有:
❑ torchvision.transforms.Resize() :随机比例缩放。
❑ torhvision.transforms.RandomCrop() :在图像随机位置进行裁取。
❑ torhvision.transforms.CenterCrop() :在图像中心置进行裁取。
❑ torchvision.transforms.RandomHorizontalFlip() :随机水平翻转。
❑ torchvision.transforms.RandomVerticalFlip() :随机竖直翻转。
❑ torchvision.transforms.RandomRotation() :随机旋转。
❑ torchvision.transforms.ColorJitter() :改变亮度、对比度和颜色。
微调的代码与特征提取的不同地方主要在图像预处理部分和参数更新部分。
这里对训练数据添加了几种数据增强方法,如图像裁剪、旋转、颜色改变等方法。测试数据与特征提取一样,没有变化。
1 | trans_train = transforms.Compose([ |
优化器部分,注意不要冻结预训练模型的参数。
1 | optimizer = optim.SGD(net.parameters(), lr = 0.001, weight_decay = 0.001, momentum = 0.9) |
结果:
由结果知微调+数据增强的方法在第三个Epoch正确率就可以达到 92% ,最终结果可达到 95% 左右,正确很高。本次实验只设置了20个Eopch,当继续增加Epoch时,正确率会接近 100% 。
参考文献:
❑ :Python深度学习基于PyTorch
❑ :TensorFlow深度学习