1 文章简述
该文章的主要思想是将预测概率最大的标记作为无标记数据的伪标签,然后给未标记数据设一个权重,在训练过程中慢慢增加未标记数据的权重来进行训练,在手写体数据集上有了较好的性能。算法流程如下:
输入:样本集 $\boldsymbol{D1 = \{(x_1, y_1),(x_2,y_2),\cdots,(x_n, y_n)\}, D2 = \{x_1,x_2,,\cdots,x_n\}}$,其中 $\boldsymbol{D_1}$ 为已标注数据,$\boldsymbol{D_2}$ 为未标注数据;
过程:
1:用 $\boldsymbol{D_1}$ 来训练得到一个初始分类器;
2:用初始分类器对 $\boldsymbol{D_2}$ 进行分类,将预测的最大预测概率的类对 $\boldsymbol{D_2}$ 进行标注,得到伪标签(Pseudo-Label);
3:使用 $\boldsymbol{D_1}$ 和得到伪标签的样本进行训练,再对 $\boldsymbol{D_2}$ 进行分类、标注,直至所有的样本均被标注;
4:使用最终网络对 MNIST 数据集进行分类;
输出:MNIST 测试集上的分类错误率。
2 伪标签(Pseudo-Label)
伪标签是未标注样本的的标签,是对当前未标注样本的预测的最大概率对应的类别,在文章中表示如下:
Pseudo-Label 用于 Dropout 的微调阶段。预训练网络以监督方式同时使用标记和未标记数据进行训练,整体的损失函数表示为:
其中:
- $n$ 是已标注样本一个 batch_size 的样本数,$n’$ 是未标注样本的一个 batch_size 的样本数,$C$ 是类别数;
- $f_i^m$ 是已标注样本的输出,$y_i^m$ 是已标注样本的真实标签;
- $f_i^{‘m}$ 是未标注样本的输出,$y_i^{‘m}$ 是未标注样本的伪标签;
- $\alpha(t)$ 是平衡已标注样本损失和未标注样本损失的系数,对模型的性能有着至关重要的影响。如果 $\alpha(t)$ 太大,会对训练产生很大成都的干扰;而如果 $\alpha(t)$ 太小,则不能利用未标注数据。
在原文中,作者 $\alpha(t)$ 缓慢增加,以帮助优化过程避免较差的局部最优点:
其中:
- $\alpha(t) = 3,T_1 = 100,T_2 = 600$;
- 即在 100 个 epoch 之前,$\alpha(t)$ 为0,此时只在有标签的数据上进行训练;
- 在100-600 个 epoch 之间时,$\alpha(t)$ 设置为(epoch_current - 100) / 500 * 3;
- 当大于 600 个 epoch 的时候,$\alpha(t)$ 为 3。
3 熵正则化
文章在公式(2)中计算模型损失大小使用的方法是交叉熵,熵可用来衡量一个系统混乱程度,在概率论中,某中情况的概率越大,代表熵越小。在文章中,作者使用了熵正则化,主要思想是用熵来衡量分类的重叠程度(class overlap),熵高的时候,重叠率也是高的,熵小的时候,重叠率是低的。论文中体现的主要的半监督方法是熵正则化,熵正则化依据于低密度假设:假设数据非黑即白,在两个类别之间的数据分布之间存在比较大的差距,即在两个类别之间的边界处数据的密度很低,用熵可以来表示,作为一个正则化项。所以通过最小化未标记数据的条件熵可以减少重叠率,从而得到一个通过低密度区域的决策边界,也就是让分类的边界更加的明显。熵正则化表达式如下:
结合公式(2)和公式(4)得到最大后验估计为:
其中:
- 公式的第一部分是已标注数据的条件似然函数,对应已标注数据的损失;
- 公式的第二部分是未标注数据的条件熵,对应未标注数据的损失,可以用来减少类别的重叠率;
- $\lambda$ 对应 $\alpha(t)$。
经上述分析,可得到通过对未标注数据建立的伪标签可以达到和熵正则化一样的效果。
4 实验结果
神经网络使用 600 个已标注样本和有或没有 60000 个未标注样本和 Pseudo-Labels 进行训练的效果分别如下:
- dropNN 表示没有使用未标注数据;
- +PL 表示使用了未标注数据。
在上图中,使用已标注数据和未标注数据进行训练的模型,测试数据的网络输出的每一类会更加接近,聚类效果更好。
不同模型使用具有 100、600、1000 和 3000 个已标注训练样本在 MNIST 测试集上的分类错误率如下:
- 已标注训练样本集的大小减少到 100、600、1000 和 3000。对于验证集,分别选取 1000 个已标注样本;
- 使用相同的网络和参数进行了 10 次随机分割实验。在 100 个已标注训练样本的情况下,结果在很大程度上取决于数据分割,因此进行了30个实验。
通过上图的性能比较,尽管原文中的方法简单,但对于小型已标注数据集,其性能优于传统方法。该训练方案没有流形切线分类器复杂,并且不使用在半监督嵌入中使用的样本之间计算代价高昂的相似度矩阵。