教你的模型自学 |作者:尼克拉斯·冯·莫尔斯 |2024 年 9 月 - 迈向数据科学

2024-09-16 19:42:47 英文原文

<正文>

教你的模型自学

基于迭代、基于置信度的伪标签分类的案例研究

在机器学习中,更多数据会带来更好的结果。但标记数据可能既昂贵又耗时。如果我们可以使用通常很容易获得的大量未标记数据怎么办?这就是伪标签派上用场的地方。

TL;DR:我对 MNIST 数据集进行了案例研究,并通过应用基于置信度的迭代伪标签将模型准确率从 90% 提高到 95%。本文介绍了伪标签的详细信息,以及我的实验中的实用技巧和见解。

它是如何工作的?

伪标签是半监督学习的一种。它弥合了监督学习(所有数据都被标记)和无监督学习(没有数据被标记)之间的差距。

我遵循的具体程序如下:

  • 我们从少量标记数据开始,并用其训练我们的模型。
  • 模型对未标记的数据进行预测。
  • 我们选择模型最有信心的预测(例如,置信度高于 95%),并将它们视为实际标签,希望它们足够可靠。
  • 我们将此伪标记数据添加到我们的训练集中并重新训练模型。
  • 我们可以多次重复此过程,让模型从不断增长的伪标记数据池中学习。

虽然这种方法可能会引入一些不正确的标签,但其好处来自于训练数据量的显着增加。

回音室效应:伪标签还能起作用吗?

模型从自己的预测中学习的想法可能会引起一些人的注意。毕竟,我们不是试图依靠回声室来从无到有地创造一些东西,而模型只是简单地强化了它自己的初始偏差和错误吗?

这种担忧是有道理的。这可能会让您想起传奇人物姆希豪森男爵(Baron Mnchhausen),他曾声称用自己的头发将自己和他的马从沼泽中拉了出来,这在物理上是不可能的。同样,如果一个模型仅仅依赖于自己可能有缺陷的预测,它就有可能陷入自我强化的循环,就像被困在回音室中的人只能听到自己的信念反射回来一样。

那么,伪标签真的能有效而不陷入这个陷阱吗?

答案是肯定的。虽然姆希豪森男爵的故事显然是一个童话故事,但你可以想象一个铁匠在各个时代的进步。他从基本的石器开始(最初的标记数据)。利用这些,他从原始矿石(未标记数据)中锻造了粗铜工具(伪标签)。这些铜工具虽然仍处于初级阶段,但使他能够完成以前无法完成的任务,最终创造出由青铜、铁等制成的工具。这个迭代过程至关重要:你无法使用石锤锻造钢剑。

就像铁匠一样,在机器学习中,我们可以通过以下方式实现类似的进展:

  • 严格的阈值:模型的样本外准确度受到正确训练标签的份额的限制。如果 10% 的标签错误,模型的准确率不会显着超过 90%。因此,尽可能少地允许错误标签非常重要。
  • 可衡量的反馈:在单独的测试集上不断评估模型性能,作为现实检查,确保取得实际进展,而不仅仅是强化现有错误。
  • 人机交互:以人工审核伪标签或人工标记低置信度数据的形式纳入人工反馈,可以提供有价值的过程修正。

如果做得正确,伪标签可以成为充分利用小型标记数据集的强大工具,正如我们将在下面的案例研究中看到的那样。

案例研究:MNIST 数据集

我在 MNIST 数据集上进行了实验,该数据集是 28 x 28 像素手写数字图像的经典集合,广泛用于机器学习模型的基准测试。它由 60,000 张训练图像和 10,000 张测试图像组成。目标是根据 28 x 28 像素预测写入的数字。

我在初始的 1,000 张标记图像上训练了一个简单的 CNN,剩下 59,000 张未标记图像。然后,我使用经过训练的模型来预测未标记图像的标签。置信度高于特定阈值(例如 95%)的预测及其预测标签被添加到训练集中。然后在这个扩展的数据集上重新训练模型。这个过程反复重复,最多十次,或者直到不再有未标记的数据。

使用不同数量的初始标记图像和置信度阈值重复此实验。

结果

下表总结了我的实验结果,比较了伪标记与在完整标记数据集上进行训练的性能。

即使初始标记数据集很小,伪标记也可能产生显着的结果,将准确度提高 4.87%pt。1,000 个初始标记样本。当仅使用 100 个初始样本时,这种效果甚至更强。然而,手动标记 100 个以上的样本是明智的。

有趣的是,100个初始训练样本的实验的最终测试精度超过了正确训练标签的份额。

从上图可以明显看出,一般来说,只要至少有一些预测超过阈值,较高的阈值就会带来更好的结果。在未来的实验中,人们可能会尝试在每次迭代中改变阈值。

此外,即使在后期迭代中,准确性也会提高,这表明迭代性质提供了真正的好处。

主要发现和经验教训

  • 当未标记数据大量但标记成本昂贵时,最好采用伪标记。
  • 监控测试准确性:在整个迭代过程中密切关注单独测试数据集上的模型性能非常重要。
  • 手动标记仍然很有帮助:如果您有资源,请重点关注手动标记低置信度数据。然而,人类也不是完美的,高可信度数据的标签可以凭良心委托给模型。
  • 跟踪人工智能生成的标签。如果稍后有更多手动标记的数据可用,您可能希望丢弃伪标签并重复此过程,以提高伪标签的准确性。
  • 解释结果时要小心:几年前我第一次做这个实验时,我关注的是剩余未标记训练数据的准确性。随着迭代次数的增加,精度会下降!然而,这可能是因为剩余的数据更难预测,模型在之前的迭代中从未对此充满信心。我应该关注测试的准确性,它实际上会随着更多的迭代而提高。

链接

可以在此处找到包含实验代码的存储库。

相关论文:具有深度特征注释和基于置信度采样的迭代伪标签

摘要

教你的模型自学基于迭代、基于置信度的伪标记分类的案例研究在机器学习中,更多的数据会带来更好的结果。这就是伪标签派上用场的地方。TL;DR:我对 MNIST 数据集进行了案例研究,并通过应用基于置信度的迭代伪标记将模型的准确率从 90% 提高到 95%。我们可以多次重复这个过程,让模型从不断增长的伪标签数据池中学习。然而,手动标记 100 多个样本是明智之举。