一种基于多标签对图像分类模型训练的方法

专利2026-02-21  7


本发明涉及神经网络领域,具体来说涉及图像分类领域,更具体地说,涉及一种基于多标签对图像分类模型训练的方法。


背景技术:

1、传统的分类方法主要关注于分布内(in-distribution,id)场景,即训练集和测试集共享相同的数据分布。然而,在实际应用中,不可避免地会有一些测试样本属于训练样本分布之外,这些样本被称为分布外(out-of-distribution,ood)样本。在这种情况下,一个id模型会由于过度自信问题,将ood样本错误分类为某个id类别,进而导致模型性能的严重下降。因此,近年来,ood检测任务引起了越来越多的关注,旨在合理甄别ood样本和id样本。

2、在该研究方向上,大多数先前的研究关注于多类别场景,其中只有真实标签主导ood推理过程。然而,在许多实际应用中,例如自动驾驶和医学图像处理,样本本质上与多个标签相关联。对于多标签ood检测,一种直观的策略是适应现成的多类别算法。然而,对于多标签任务,联合建模不同标签间的信息至关重要。受此启发,一些研究者在多标签ood检测方面进行了开创性的工作,指出仅使用深度神经网络输出可能并不是最佳选择。为了解决这个问题,他们建议将所有标签的能量分数总和作为推理标准(称为jointenergy),其中能量分数与对数输出(logit)成正比。

3、尽管jointenergy对ood推理有效,但发明人发现该方法引发了一个新的不平衡问题:在多标签分布外样本检测任务中,由于分布内样本与分布外样本之间的能量间距较小,并且头部样本和尾部样本之间巨大的样本数量差异,既往方法只充分学习到了头部样本信息,而面对尾部样本时总会产生严重的性能下降;故而jointenergy可能会促进“马太效应”,也即,头部样本往往具有较大的logit值,更有可能被识别为id样本,相比之下,尾部样本通常具有较低的logit值,而被错误地分类为ood样本;其中,尾部意味着相关样本都属于少数类别,头部意味着样本至少与一个多数类别相关,少数/多数指的是类别与少/多样本相关联。

4、需要说明的是:本背景技术仅用于介绍本发明的相关信息,以便于帮助理解本发明的技术方案,但并不意味着相关信息必然是现有技术。在没有证据表明相关信息已在本发明的申请日以前公开的情况下,相关信息不应被视为现有技术。


技术实现思路

1、因此,本发明的目的在于克服上述现有技术的缺陷,提供一种基于多标签对图像分类模型训练的方法。

2、本发明的目的是通过以下技术方案实现的:

3、根据本发明的第一方面,提供了一种基于多标签对图像分类模型训练的方法,所述方法包括:获取训练集,所述训练集包括多个分布内图像样本和多个分布外图像样本,以及指示各个分布内图像样本所属类别的标签,其中,所述分布外图像样本为不属于分布内图像样本所属任意类别的图像样本;利用所述训练集对所述图像分类模型进行一次或者多次训练,其中,利用图像分类模型提取输入图像在各个类别上的logit值和置信度,每次训练时,基于预设的总损失函数确定的总损失更新图像分类模型的参数,得到经训练的图像分类模型,所述总损失根据以下损失加权求和确定:分布内损失、分布外损失和能量分布差距损失;其中,所述分布内损失根据图像分类模型输出的分布内图像样本在各个类别上的置信度以及对应的标签进行确定;所述分布外损失根据图像分类模型输出的分布外图像样本在各个类别上的置信度进行确定;所述能量分布差距损失根据所有分布外图像样本的能量分数的均值以及低于预设能量分数阈值的多个分布内图像样本的能量分数的均值进行确定,其中,所述能量分数根据图像样本在各个类别上的logit值进行确定。

4、在本发明的一些实施例中,所述分布内损失是图像分类模型输出的分布内图像样本在各个类别上的置信度以及对应的标签计算的交叉熵损失;所述分布外损失被配置为:与各个分布外图像样本在所有类别上的置信度正相关;所述能量分布差距损失被配置为:与所有分布外图像样本的能量分数的均值正相关,与低于预设能量分数阈值的多个分布内图像样本的能量分数的均值负相关。

5、在本发明的一些实施例中,所述总损失为:

6、

7、其中,表示分布内损失,表示分布外损失,表示能量分布差距损失,λ、α和β分别表示分布内损失、分布外损失和能量分布差距损失对应的权值。

8、在本发明的一些实施例中,所述分布内损失为:

9、

10、其中,表示分布内图像样本的集合,p表示该集合中分布内图像样本的总数,xi表示第i个分布内图像样本,c表示类别的总数,yij表示第i个分布内图像样本在第j类上对应的标签,σj(xi)表示第i个分布内图像样本在第j类上的置信度。

11、在本发明的一些实施例中,所述分布外损失为:

12、

13、其中,表示分布外图像样本的集合,n表示该集合中分布外图像样本的总数,x′i表示第i个分布外图像样本,c表示类别的总数,σ(x′i)表示第i个分布外图像样本在第j类上的置信度。

14、在本发明的一些实施例中,所述能量分布差距损失为:

15、

16、其中,ε(x′)表示所有分布外图像样本的能量分数,表示低于预设能量分数阈值的k个分布内图像样本的能量分数的均值,m表示超参数;

17、所述能量分数按以下方式计算得到:

18、

19、其中,x表示图像样本,c表示类别,fj(x)表示图像样本在第j类上的logit值。

20、在本发明的一些实施例中,所述分布外样本按照以下方式进行筛选得到:获取数据集,所述数据集包括多个分布内图像样本以及多个候选分布外图像样本;获取图像分类模型,所述图像分类模型包括用于提取图像特征的特征提取器以及用于根据图像特征提取图像在各个类别上的logit值的全连接层;利用所述图像分类模型提取所述数据集中各个图像样本在各个类别上的倒数第二层特征,并确定该特征的奇异值对角矩阵,其中,所述倒数第二层特征为提取logit值的全连接层的前一层提取得到的特征;根据各个图像样本的奇异值对角矩阵计算分布内图像样本与分布外图像样本的特征相似度,并从候选分布外图像样本中选取与分布内图像样本的特征相似度较大的图像样本作为所述训练集中的分布外图像样本。

21、根据本发明的第二方面,提供了一种图像分类的方法,所述方法包括:获取如本发明的第一方面得到的经训练的图像分类模型;获取待分类图像;利用所述图像分类模型提取所述待分类图像各个类别上的logit值,根据所述logit值判断所述待分类图像是否为分布外图像;若是,输出未知类别;若否,根据所述待分类图像各个类别上的logit值确定所述待分类图像所属的类别。

22、与现有技术相比,本发明的优点在于:

23、1)通过分布内损失、分布外损失以及能量分布差距损失加权求和的总损失更新模型的参数,使得模型能够更好地拟合分布内样本,提高模型对分布内样本识别的准确性,同时保持模型对分布外样本的泛化能力,还能够扩大分布内样本与分布外样本之间的能量分布差距,确定分布内与分布外之间的决策边界,以学习整体数据的分布,进而能够更好地区别分布内样本和分布外样本,提高模型的识别能力。

24、2)通过筛选算法从候选分布外图像样本中选取与分布内图像样本之间的特征相似度较大的图像样本,能够在一定程度上筛选出更有利于提成模型分类能力的训练样本,使得训练后的模型能够更好的学习分布内样本与分布外样本之间的决策边界,进而能够更好地识别分布内样本和分布外样本,提高模型的识别性能。


技术特征:

1.一种基于多标签对图像分类模型训练的方法,其特征在于,所述方法包括:

2.根据权利要求1所述的方法,其特征在于,

3.根据权利要求2所述的方法,其特征在于,所述总损失为:

4.根据权利要求3所述的方法,其特征在于,所述分布内损失为:

5.根据权利要求4所述的方法,其特征在于,所述分布外损失为:

6.根据权利要求5所述的方法,其特征在于,所述能量分布差距损失为:

7.根据权利要求1-6之一所述的方法,其特征在于,所述分布外样本按照以下方式进行筛选得到:

8.一种图像分类的方法,其特征在于,所述方法包括:

9.一种计算机可读存储介质,其特征在于,其上存储有计算机程序,所述计算机程序可被处理器执行以实现权利要求1至8中任一项所述方法的步骤。

10.一种电子设备,其特征在于,包括:


技术总结
本发明提供了一种基于多标签对图像分类模型训练的方法,所述方法包括:获取训练集,所述训练集包括多个分布内图像样本和多个分布外图像样本,以及指示各个分布内图像样本所属类别的标签,其中,所述分布外图像样本为不属于分布内图像样本所属任意类别的图像样本;利用训练集对图像分类模型进行一次或者多次训练,其中,利用图像分类模型提取输入图像在各个类别上的logit值和置信度,每次训练时,基于预设的总损失函数确定的总损失更新图像分类模型的参数,得到经训练的图像分类模型,所述总损失根据以下损失加权求和确定:分布内损失、分布外损失和能量分布差距损失。

技术研发人员:许倩倩,孙宇辰,王子泰,何俊伟,杨智勇,黄庆明
受保护的技术使用者:中国科学院计算技术研究所
技术研发日:
技术公布日:2024/6/26
转载请注明原文地址:https://doc.8miu.com/read-1827932.html

最新回复(0)