在机器学习中,更多数据会带来更好的结果。但标记数据可能既昂贵又耗时。如果我们可以使用通常很容易获得的大量未标记数据怎么办?这就是伪标签派上用场的地方。
TL;DR:我对 MNIST 数据集进行了案例研究,并通过应用基于置信度的迭代伪标签将模型准确率从 90% 提高到 95%。本文介绍了伪标签的详细信息,以及我的实验中的实用技巧和见解。
伪标签是半监督学习的一种。它弥合了监督学习(所有数据都被标记)和无监督学习(没有数据被标记)之间的差距。
我遵循的具体程序如下:
虽然这种方法可能会引入一些不正确的标签,但其好处来自于训练数据量的显着增加。
模型从自己的预测中学习的想法可能会引起一些人的注意。毕竟,我们不是试图依靠回声室来从无到有地创造一些东西,而模型只是简单地强化了它自己的初始偏差和错误吗?
这种担忧是有道理的。这可能会让您想起传奇人物姆希豪森男爵(Baron Mnchhausen),他曾声称用自己的头发将自己和他的马从沼泽中拉了出来,这在物理上是不可能的。同样,如果一个模型仅仅依赖于自己可能有缺陷的预测,它就有可能陷入自我强化的循环,就像被困在回音室中的人只能听到自己的信念反射回来一样。
那么,伪标签真的能有效而不陷入这个陷阱吗?
答案是肯定的。虽然姆希豪森男爵的故事显然是一个童话故事,但你可以想象一个铁匠在各个时代的进步。他从基本的石器开始(最初的标记数据)。利用这些,他从原始矿石(未标记数据)中锻造了粗铜工具(伪标签)。这些铜工具虽然仍处于初级阶段,但使他能够完成以前无法完成的任务,最终创造出由青铜、铁等制成的工具。这个迭代过程至关重要:你无法使用石锤锻造钢剑。
就像铁匠一样,在机器学习中,我们可以通过以下方式实现类似的进展:
如果做得正确,伪标签可以成为充分利用小型标记数据集的强大工具,正如我们将在下面的案例研究中看到的那样。
我在 MNIST 数据集上进行了实验,该数据集是 28 x 28 像素手写数字图像的经典集合,广泛用于机器学习模型的基准测试。它由 60,000 张训练图像和 10,000 张测试图像组成。目标是根据 28 x 28 像素预测写入的数字。
我在初始的 1,000 张标记图像上训练了一个简单的 CNN,剩下 59,000 张未标记图像。然后,我使用经过训练的模型来预测未标记图像的标签。置信度高于特定阈值(例如 95%)的预测及其预测标签被添加到训练集中。然后在这个扩展的数据集上重新训练模型。这个过程反复重复,最多十次,或者直到不再有未标记的数据。
使用不同数量的初始标记图像和置信度阈值重复此实验。
下表总结了我的实验结果,比较了伪标记与在完整标记数据集上进行训练的性能。
即使初始标记数据集很小,伪标记也可能产生显着的结果,将准确度提高 4.87%pt。1,000 个初始标记样本。当仅使用 100 个初始样本时,这种效果甚至更强。然而,手动标记 100 个以上的样本是明智的。
有趣的是,100个初始训练样本的实验的最终测试精度超过了正确训练标签的份额。
从上图可以明显看出,一般来说,只要至少有一些预测超过阈值,较高的阈值就会带来更好的结果。在未来的实验中,人们可能会尝试在每次迭代中改变阈值。
此外,即使在后期迭代中,准确性也会提高,这表明迭代性质提供了真正的好处。
可以在此处找到包含实验代码的存储库。
相关论文:具有深度特征注释和基于置信度采样的迭代伪标签