知行编程网知行编程网  2022-07-06 18:00 知行编程网 隐藏边栏 |   抢沙发  17 
文章评分 0 次,平均分 0.0

NLP样本不均衡之常用损失函数对比(附代码)

来自 | 知乎   作者 | 青空栀浅
链接 | https://zhuanlan.zhihu.com/p/191355122
编辑 | 深度学习这件小事
本文仅作学术交流,如有侵权,请联系后台删除。

   0 前言

本文分为三个部分,第一个部分主要介绍一下在分类问题中为什么用交叉熵作为损失函数,第二部分主要介绍一下在交叉熵的基础上的一些改进的损失函数,最后使用上述的几种损失函数在 CLUENER 细粒度命名实体识别的数据集上进行效果的效果对比。

   1 为什么使用交叉熵作为损失函数?

交叉熵作为最常用的损失函数想必再熟悉不过了,本节简单介绍一下交叉熵作为loss的优势。

1.1 为什么不用classification error?


下面通过两个例子来说明一下:

假设下面为两个模型的结果,其中predict为模型的预测值,target为真实的分类。

model_1:



model_2:



通过上面两个表的对比,可以看出模型2的效果要明显优于模型1,因为在正确的分类上模型的预测结果都显著大于其他类别。但是从classification error上发现两个模型的效果是一样的,发现问题了吧,classification error不能准确的描述模型与理想的模型的距离。

1.2 为什么不用Mean Squared Error


model_1:


model_2:


MSE 看起来也不错,不用它的原因是因为:

如果用 MSE 计算 loss, 其损失函数是一个非凸函数,存在很多局部极值点并且MSE更容易导致梯度弥散。

关于这个问题,已经有很多回答了,这里就不展开说了。

NLP样本不均衡之常用损失函数对比(附代码)

1.3 使用交叉熵的损失

NLP样本不均衡之常用损失函数对比(附代码)

model_1:


model_2:


cross-entropy 更清晰的描述了模型与理想模型的距离,并且CE作为loss其也是一个凸函数。


   交叉熵的基础上的一些改进的损失函数


2.1 Label Smoothing


在这个公式中,  表示  的标准交叉熵损失函数,  是一个非常小的正数,  表示对应的正确分类,  为所有分类的数量。

直观上看,标记平滑限制了正确类的  值,并使得它更接近于其他类的  值。从而在一定程度上,它被当作为一种正则化技术和一种对抗模型过度自信的方法。

2.2 Focal Loss

Focal Loss的引入主要是为了解决难易样本数量不平衡(注意,有区别于正负样本数量不平衡)的问题,这里将样本分为四个类型

NLP样本不均衡之常用损失函数对比(附代码)

我们在计算分类的时候常用的损失——交叉熵的公式如下:

NLP样本不均衡之常用损失函数对比(附代码)

为了解决正负样本不平衡的问题,我们通常会在交叉熵损失的前面加上一个参数  ,即:

NLP样本不均衡之常用损失函数对比(附代码)

尽管  平衡了正负样本,但对难易样本的不平衡没有任何帮助。因为易分样本虽然损失很低,但是数量太多,导致最终主导了损失。而作者认为,易分样本(即,置信度高的样本)对模型的提升效果非常小,模型应该主要关注与那些难分样本(这个假设是有问题的,是GHM的主要改进对象)。一个简单的思想:把高置信度(p)样本的损失再降低一些不就好了,即:

NLP样本不均衡之常用损失函数对比(附代码)

举个例子  ,如果  ,那么  ,损失衰减了1000倍!

Focal Loss的最终形式如下:

NLP样本不均衡之常用损失函数对比(附代码)

即结合了(2)和(3)式,简单明了。

2.3 GHM_C

GHM_C是在Focal的基础上的进一步改进,首先看一下Focal存在的问题:

  • 首先,让模型过度关注特别难分的样本其实是有问题的,因为那种样本点很可能本身就是离群点或者是标注错误等造成的样本。


  • 其次,Focal中的  与  是根据实验得出来的,并且 与 的值也会相互影响,不同数据需要不断尝试调整 与 


研究者对样本不均衡的本质影响进行了进一步探讨,找到了梯度分布这个更为深入的角度。认为样本不均衡的本质是分布的不均衡。更进一步来看,每个样本对模型训练的实质作用是产生一个梯度用以更新模型的参数,不同样本对参数更新会产生不同的贡献。由于简单样本的数量非常大,它们产生的累计贡献就在模型更新中就会有巨大的影响力甚至占据主导作用,而由于它们本身已经被模型很好的判别,所以这部分的参数更新并不会改善模型的判断能力,也就使整个训练变得低效。

基于这一点,研究者对样本梯度的分布进行了统计,并根据这个分布设计了一个梯度均衡机制,使得模型训练更加高效与稳健,并可以收敛到更好的结果。

首先定义一个统计对象——梯度模长(gradient norm)。考虑简单的二分类交叉熵损失函数:

NLP样本不均衡之常用损失函数对比(附代码)

其中  为模型所预测的样本类别的概率,  是真实类别[0,1]。则其对  的梯度(导数)为:

NLP样本不均衡之常用损失函数对比(附代码)

于是定义梯度模长,g:

NLP样本不均衡之常用损失函数对比(附代码)

下图为论文中展示的梯度模长与样本量的关系,可以看出绝大部分的样本集中于梯度模长较小的区域内(图中左侧),也就是所谓的易分样本。

NLP样本不均衡之常用损失函数对比(附代码)

那怎么同时衰减易分样本和特别难分的样本呢?太简单了,谁的数量多衰减谁呗!那怎么衰减数量多的呢?简单啊,定义一个变量,让这个变量能衡量出一定梯度范围内的样本数量——这不就是物理上密度的概念吗?

接下来定义梯度密度:

于是,作者定义了梯度密度  ——本文最重要的公式:

NLP样本不均衡之常用损失函数对比(附代码)

其中:

NLP样本不均衡之常用损失函数对比(附代码)

 就是在样本N中,梯度模长分布在  范围内的样本数。  代表了  的区间长度。

因此梯度密度的物理含义是:单位梯度模长g部分的样本个数。

接下来梯度均衡的参数定义为:
NLP样本不均衡之常用损失函数对比(附代码)

其中的  就是所有的样本数。

最终GHM_C的定义形式就是:

NLP样本不均衡之常用损失函数对比(附代码)

接下来看一下该损失函数的效果:

NLP样本不均衡之常用损失函数对比(附代码)

左边的图的意思就是不同损失函数对样本梯度产生的作用,横坐标为在交叉熵(CE)损失函数下样本的梯度模长,纵坐标为新的损失函数下同样的样本新的梯度模长,可以看出FL作为损失函数的情况下对于易分样本的确是存在抑制作用的,GHM-C作为损失函数不仅对易分样本有抑制,并且对特别难分的样本也是存在抑制作用的。右图中也可以看出GHM_C在易分样本上有相对的抑制,对于其他样本的梯度较为均衡。

   不同损失函数的评测


在 CLUENER 细粒度命名实体识别的数据集上对上文提到的损失函数进行效果的对比。
因为命名实体识别的的模型相对于其他模型的数据样本就是很不平衡的。

数据集


数据集地址:
https://www.cluebenchmarks.com/introduce.html


各标签分布如下:

NLP样本不均衡之常用损失函数对比(附代码)
NLP样本不均衡之常用损失函数对比(附代码)

在10种类型中,book,movie,scene的占比相对较少。

模型


模型使用的是Bert+two_pointer的方式,即预测实体的start-position和end-position的方式。

这种形式下构建的数据集的标签也是很不均衡的。例如:


可以看出其实大部分的标签是0,就是非实体的意思。

使用该模型对以上的四种损失函数的测试结果如下:

NLP样本不均衡之常用损失函数对比(附代码)

表格中前几列为各个标签的F1值,最后一列为总体的F1值。

下图为四种损失函数的变化趋势:

NLP样本不均衡之常用损失函数对比(附代码)
NLP样本不均衡之常用损失函数对比(附代码)

从loss曲线来看,四个损失函数下模型已经收敛,从F1效果看整体的效果是:

lsr的效果最好,其次是focal,然后是ghmc最后是ce。

也可以看出四个损失函数下其实模型的差距并不显著。

相关代码已经上传,详见我的github:
https://github.com/qingkongzhiqian/NER_loss_compare

   ToDo

目前只是测试四种损失函数在 CLUENER 细粒度命名实体识别的数据集上的效果,还没有统计该数据集的梯度模长的分布,后续添加吧!

   总结

本文首先介绍了为什么用交叉熵作为损失函数,然后介绍了一些常用的损失函数,最后使用上述的几种损失函数在 CLUENER 细粒度命名实体识别的数据集上进行效果的效果对比。

参考链接:
https://zhuanlan.zhihu.com/p/80594704
https://zhuanlan.zhihu.com/p/80594704
https://github.com/CLUEbenchmark/CLUENER2020
https://www.cnblogs.com/maybe2030/p/9163479.html


<section data-brushtype="text" style="padding-right: 0em;padding-left: 0em;white-space: normal;max-width: 100%;letter-spacing: 0.544px;color: rgb(62, 62, 62);font-family: "Helvetica Neue", Helvetica, "Hiragino Sans GB", "Microsoft YaHei", Arial, sans-serif;widows: 1;word-spacing: 2px;caret-color: rgb(255, 0, 0);text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong>完<strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong></span></strong></span></strong></section><pre style="padding-right: 0em;padding-left: 0em;max-width: 100%;letter-spacing: 0.544px;color: rgb(62, 62, 62);widows: 1;word-spacing: 2px;caret-color: rgb(255, 0, 0);text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><pre style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-bottom: 15px;padding-right: 0em;padding-left: 0em;max-width: 100%;color: rgb(127, 127, 127);font-size: 12px;font-family: sans-serif;line-height: 25.5938px;letter-spacing: 3px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;font-size: 16px;font-family: 微软雅黑;caret-color: red;box-sizing: border-box !important;overflow-wrap: break-word !important;">为您推荐</span></strong></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;">干货 | 算法工程师超实用技术路线图</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">13个算法工程师必须掌握的PyTorch Tricks</span><span style="letter-spacing: 0.544px;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;"></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">吴恩达上新:生成对抗网络(GAN)专项课程</span><br  /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;">拿到2021灰飞烟灭算法岗offer的大佬们是啥样的<span style="font-size: 14px;">?</span><br  /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;">你一定从未看过如此通俗易懂的YOLO系列解读 (下)</section></section></section></section></section></section></section></section></section>

NLP样本不均衡之常用损失函数对比(附代码)

本篇文章来源于: 深度学习这件小事

本文为原创文章,版权归所有,欢迎分享本文,转载请保留出处!

知行编程网
知行编程网 关注:1    粉丝:1
这个人很懒,什么都没写

发表评论

表情 格式 链接 私密 签到
扫一扫二维码分享