基于分支学习和分层伪标签的行人重识别网络训练方法与流程

专利2022-05-10  51



1.本发明涉及一种行人重识别网络训练方法,尤其是涉及一种基于分支学习和分层伪标签的行人重识别网络训练方法。


背景技术:

2.行人重识别是一项跨域识别同一行人的任务,在目标自动识别中有着重要的地位。最近几年,有许多研究重点关注需要大量标注数据的全监督行人重识别,然而在生活中,大量标注数据往往会消耗大量的人力和时间成本,且在一些情境下,如刑侦调查时,往往缺少大量的标注数据,而每个行人仅有一张标注图像以供网络训练。由此引出单样本行人重识别这一具有意义的研究课题。
3.目前针对单样本的行人重识别,已经有了一些有价值的研究。有一些研究通过丰富行人的特征来致力于增加识别的精度,一些研究通过扩大训练数据集的规模来提升网络的效果,进而达到提高识别率的效果。通常,扩大训练数据集又有两个思路,一是生成新的可训练数据,二是为无标签数据赋值伪标签将其转换成标签数据参与训练。生成新数据的方法虽然能有效增大训练数据的规模,但其无法充分挖掘已有的标签数据的信息。于是伪标签法成为了应用更为广泛的半监督学习方法。伪标签法分为半监督学习的伪标签法和无监督学习的伪标签法,其中,半监督的伪标签法包括标签传播法和k近邻聚类等,无监督伪标签法包括k

means聚类和dbscan聚类等。目前已有的伪标签法大部分仅单用一种方法,然而不同的伪标签法有着不同的适用范围,能从不同视角对无标签数据赋值伪标签,仅用一种方法会限制其使用效果。更重要的是,对于大部分的伪标签法而言,伪标签数据往往被视作和标签数据具有同等的地位,并将它们混合在一起进行训练。实际上,伪标签数据的噪声导致它不能提供和标签数据一样准确的信息,而且不同伪标签法获得的伪标签数据也具有不同的噪声,因而需要将它们分组分别进行训练。不同类型的数据又具有不同的特点,因而对不同组的数据使用相同的损失函数是不合理的,需要针对不同组的特点设计个性化的损失函数。


技术实现要素:

4.本发明的目的就是为了克服上述现有技术存在的缺陷而提供一种基于分支学习和分层伪标签的行人重识别网络训练方法。
5.本发明的目的可以通过以下技术方案来实现:
6.一种基于分支学习和分层伪标签的行人重识别网络训练方法,所述的行人重识别网络为相互平均教学网络,所述的相互平均教学网络包括两个结构相同的网络net1和net2以及对应的平均网络mean net1和mean net2,所述的训练方法包括:
7.获取标签数据集和无标签数据集,将标签数据集作为一层,将无标签数据集分为n层,并对各层的无标签数据分别赋值伪标签,形成n层伪标签数据,n为常数;
8.构建分支学习框架,包括n 1个共享权重的相互平均教学网络分支,其中一个分支
用于输入标签数据进行训练,其余n个分支分别对应输入n层伪标签数据进行训练;
9.构建各分支的损失函数,确定分支学习框架的总损失函数,基于总损失函数对分支学习框架进行多轮训练,每一轮训练过程中对无标签数据集重新进行分层。
10.优选地,所述的无标签数据集分为2层,具体为:将与标签数据集中的标签数据距离较近的若干无标签数据分作一层,剩余无标签数据分作一层。
11.优选地,所述的无标签数据集分层的具体方式为:
12.对标签数据集和无标签数据集中的标签数据和无标签数据分别采用特征提取器进行特征提取,标签数据特征记作无标签数据特征记作θ
o
为特征提取器;
13.计算无标签数据集中任意一个无标签数据和标签数据集中任意一个标签数据间的欧式距离并取最小值,计算公式为:
[0014][0015]
其中,||
·
||表示欧氏距离,l表示标签数据集;
[0016]
将无标签数据对应的由小到大排序,选取前p个无标签数据作为第一层伪标签数据,称作最近邻伪标签数据,其余无标签数据剔除掉其中的聚类离群点后作为第二层伪标签数据,称作聚类伪标签数据。
[0017]
优选地,在每一轮训练过程中更新p的大小,更新方式表示为:
[0018][0019]
其中,u表示无标签数据集中,|u|表示无标签数据集的样本个数,0<γ<1,epoch为训练轮数。
[0020]
优选地,对各层无标签数据赋值伪标签的方式为:
[0021]
对于最近邻伪标签数据,将与其欧式距离最小的有标签数据的标签作为此最近邻伪标签数据的伪标签;
[0022]
对于聚类伪标签数据,基于提取的特征对全部有标签数据和无标签数据进行聚类,将属于同一聚类类型中的有标签数据的标签作为该聚类类型中聚类伪标签数据的伪标签。
[0023]
优选地,采用dbscan聚类法对全部有标签数据和无标签数据进行聚类。
[0024]
优选地,在多轮训练过程中,所述的特征提取器θ
o
不断更新,更新方式为:
[0025]
首轮训练时,采用预设的resnet50神经网络作为该轮训练的特征提取器;
[0026]
第k轮训练时,提取k

1轮训练过的相互平均教学网络,选择net1和net2中测试指标map更高的一个网络去掉分类器作为第k轮训练的特征提取器,k≥2。
[0027]
优选地,所述的分支学习框架的总损失函数记作l,表示为:
[0028]
[0029]
其中,分别表示输入标签数据分支的分类损失、软分类损失、难样本三元组损失和软三元组损失,分别表示输入最近邻伪标签数据分支的分类损失、软分类损失、难样本三元组损失和软三元组损失,分支的分类损失、软分类损失、难样本三元组损失和软三元组损失,分别表示输入聚类伪标签数据分支的难样本三元组损失和软三元组损失,l
bd
表示输入标签数据分支的类间距离损失,l
gc
表示输入最近邻伪标签数据分支的全局中心损失,λ1,λ2,α1表示权重。
[0030]
优选地,输入标签数据分支的类间距离损失l
bd
表示为:
[0031]
l
bd
=l
bd
‑1 l
bd
‑2[0032][0033][0034]
其中,l
bd
‑1表示用于训练net1的类间距离损失,l
bd
‑2表示用于训练net2的类间距离损失,其中,l
b
表示当前训练批次的训练样本集,n
b
表示训练样本集l
b
中的样本数,和表示l
b
中的标签数据样本,和分别为输入标签数据分支中相互平均教学网络的net1和net2提取出标签数据的特征,分别为输入标签数据分支中相互平均教学网络的net1和net2提取出标签数据的特征,θ1、θ2表示net1和net2的特征提取器,||
·
||表示欧氏距离。
[0035]
优选地,输入最近邻伪标签数据分支的全局中心损失l
gc
通过如下方式获得:
[0036]
对于标签数据其对应标签为j,输入标签数据分支中的相互平均教学网络的平均网络mean net1和mean net2提取出标签数据的特征为和e
t
[θ1]、e
t
[θ2]分别为所述的平均网络mean net1和mean net2的特征提取器,将两个特征进行融合,并将融合结果记为标签j的全局类中心c
j
,其表达式为:
[0037][0038]
采用一个记忆模块来存储这些全局类中心,每完成一轮训练更新一次全局类中心的大小;
[0039]
第一轮训练时输入最近邻伪标签数据分支的全局中心损失l
gc
取作0;
[0040]
从第二轮训练开始,输入最近邻伪标签数据分支的全局中心损失l
gc
通过下式获得:
[0041][0042]
其中,表示第i个最近邻伪标签数据,n
b
表示最近邻伪标签数据的总个数,分别为输入最近邻伪标签数据分支中的相互平均教学网络中net1和net2
提取出最近邻伪标签数据的特征,θ1、θ2表示net1和net2的特征提取器,y
i
表示的伪标签。
[0043]
与现有技术相比,本发明具有如下优点:
[0044]
(1)本发明方法能够充分挖掘无标签数据的信息,为网络提供内容更丰富的训练数据,使得训练的网络更加精确;
[0045]
(2)本发明方法能有效缩短训练时网络的收敛速度。
附图说明
[0046]
图1为本发明一种基于分支学习和分层伪标签的行人重识别网络训练方法的流程示意图。
具体实施方式
[0047]
下面结合附图和具体实施例对本发明进行详细说明。注意,以下的实施方式的说明只是实质上的例示,本发明并不意在对其适用物或其用途进行限定,且本发明并不限定于以下的实施方式。
[0048]
实施例
[0049]
本实施例提供一种基于分支学习和分层伪标签的行人重识别网络训练方法,行人重识别网络为相互平均教学网络(mmt网络),相互平均教学网络为现有的网络结构为2020年发表于international conference on learning representations(iclr)中的文章“mutual mean

teaching:pseudo label refinery for unsupervised domain adaptation on person re

identification”提出的一种新型网络结构,其包含两个结构相同的网络net1和net2,以及它们对应的的平均网络mean net1和mean net2。
[0050]
如图1所示,本实施例提供的训练方法包括:
[0051]
获取标签数据集和无标签数据集,将标签数据集作为一层,将无标签数据集分为n层,并对各层的无标签数据分别赋值伪标签,形成n层伪标签数据,n为常数;
[0052]
构建分支学习框架,包括n 1个共享权重的相互平均教学网络分支,其中一个分支用于输入标签数据进行训练,其余n个分支分别对应输入n层伪标签数据进行训练;
[0053]
构建各分支的损失函数,确定分支学习框架的总损失函数,基于总损失函数对分支学习框架进行多轮训练,每一轮训练过程中对无标签数据集重新进行分层,重复训练直至网络收敛到最好结果。
[0054]
无标签数据集分为2层,具体为:将与标签数据集中的标签数据距离较近的若干无标签数据分作一层,剩余无标签数据分作一层。
[0055]
无标签数据集分层的具体方式为:
[0056]
对标签数据集和无标签数据集中的标签数据和无标签数据分别采用特征提取器进行特征提取,标签数据特征记作无标签数据特征记作θ
o
表示特征提取器;
[0057]
计算无标签数据集中任意一个无标签数据和标签数据集中任意一个标签数据
间的欧式距离并取最小值,计算公式为:
[0058][0059]
其中,||
·
||表示欧氏距离,l表示标签数据集;
[0060]
将无标签数据对应的由小到大排序,选取前p个无标签数据作为第一层伪标签数据,称作最近邻伪标签数据,其余无标签数据剔除掉其中的聚类离群点后作为第二层伪标签数据,称作聚类伪标签数据。
[0061]
在每一轮训练过程中更新p的大小,更新方式表示为:
[0062][0063]
其中,u表示无标签数据集中,|u|表示无标签数据集的样本个数,0<γ<1,epoch为训练轮数。
[0064]
对各层无标签数据赋值伪标签的方式为:
[0065]
对于最近邻伪标签数据,将与其欧式距离最小的有标签数据的标签作为此最近邻伪标签数据的伪标签;
[0066]
对于聚类伪标签数据,基于提取的特征对全部有标签数据和无标签数据进行聚类,将属于同一聚类类型中的有标签数据的标签作为该聚类类型中聚类伪标签数据的伪标签,采用dbscan聚类法对全部有标签数据和无标签数据进行聚类。
[0067]
在多轮训练过程中,特征提取器θ
o
不断更新,更新方式为:
[0068]
首轮训练时,采用预设的resnet50神经网络作为该轮训练的特征提取器;
[0069]
第k轮训练时,提取k

1轮训练过的相互平均教学网络,选择net1和net2中测试指标map更高的一个网络去掉分类器作为第k轮训练的特征提取器,k≥2。
[0070]
由此,本实施例将标签数据集和无标签数据集数据分为3层,分别为标签数据层、最近邻为标签数据层和聚类伪标签数据层,从而分支学习框架包括3个共享权重的相互平均教学网络分支。在训练过程中,不断更新最近邻为标签数据层和聚类伪标签数据层中的数据,从而使得网络识别精度越来越好。
[0071]
将标签数据层、最近邻为标签数据层和聚类伪标签数据层这三个层上的数据分别输入到不同的共享权重的mmt分支,并对每个分支用不同的损失函数进行训练。对于标签数据分支,采用分类损失、软分类损失、难样本三元组损失、软三元组损失和设计的类间距离损失进行训练;对于最近邻伪标签数据分支,同样采用了分类损失、软分类损失、难样本三元组损失、软三元组损失进行训练,并额外为其设计了全局中心损失以使训练向着缩小类间距的方向进行;对于聚类伪标签数据分支,因为这些数据的伪标签来源于聚类算法而非标签数据,其伪标签不能代表行人身份信息,因而其不能用分类损失和软分类损失而仅能用难样本三元组损失和软三元组损失进行训练。上述分类损失、软分类损失、难样本三元组损失、软三元组损失均为文章“mutual mean

teaching:pseudo label refinery for unsupervised domain adaptation on person re

identification”中提出的几种损失函数,在本实施例中不详细说明。
[0072]
对于最近邻为标签数据层,设计了类间距离损失,目的是让网络对不同的类之间
有更好的区分度,且难样本三元组损失仅学习一层中距离最近的负样本对,而忽略了其他负样本对的学习,可能会导致学习信息的丢失。类间距离损失的主要思想是由于所有的标签数据均不属于同一个类别,因而我们将标签数据在特征空间中彼此推开,输入标签数据分支的类间距离损失l
bd
表示为:
[0073]
l
bd
=l
bd
‑1 l
bd
‑2[0074][0075][0076]
其中,l
bd
‑1表示用于训练net1的类间距离损失,l
bd
‑2表示用于训练net2的类间距离损失,其中,l
b
表示当前训练批次的训练样本集,n
b
表示训练样本集l
b
中的样本数,和表示l
b
中的标签数据样本,和分别为输入标签数据分支中相互平均教学网络的net1和net2提取出标签数据的特征,分别为输入标签数据分支中相互平均教学网络的net1和net2提取出标签数据的特征,θ1、θ2表示net1和net2的特征提取器,||
·
||表示欧氏距离。
[0077]
传统的中心损失仅针对一层中的数据而不是全体训练数据,这会导致其在行人重识别上的应用受到限制。而且多分支学习框架中,上文提到的损失函数仅能学习同一层上的数据而不能学习不同层上的数据。针对这两点,设计了全局中心损失,其核心思想是使第二层的伪标签数据能紧密围绕在对应的标签数据周围。因此,输入最近邻伪标签数据分支的全局中心损失l
gc
通过如下方式获得:
[0078]
对于标签数据其对应标签为j,输入标签数据分支中的相互平均教学网络的平均网络mean net1和mean net2提取出标签数据的特征为和e
t
[θ1]、e
t
[θ2]分别为平均网络mean net1和mean net2的特征提取器,将两个特征进行融合,并将融合结果记为标签j的全局类中心c
j
,其表达式为:
[0079][0080]
采用一个记忆模块来存储这些全局类中心,每完成一轮训练更新一次全局类中心的大小;
[0081]
第一轮训练时输入最近邻伪标签数据分支的全局中心损失l
gc
取作0;
[0082]
从第二轮训练开始,输入最近邻伪标签数据分支的全局中心损失l
gc
通过下式获得:
[0083][0084]
其中,表示第i个最近邻伪标签数据,n
b
表示最近邻伪标签数据的总个数,分别为输入最近邻伪标签数据分支中的相互平均教学网络中net1
person re

identifification with one sample,in:2019ieee 31st international conference on tools with artifificial intelligence(ictai),2019.
[0096]
【5】h.li,j.xiao,m.sun,e.g.lim,y.zhao,progressive sample mining and representation learning for one

shot person re

identifification,pattern recognition 110.doi:10.1016/j.patcog.2020.107614.
[0097]
由上表可知,本发明方法能在有限的标签训练样本的条件下,充分利用了全部无标签数据信息,并专业化地对不同类别数据进行分组训练,从而训练出一个性能较好的网络以完成行人重识别任务,我们的方法比目前存在的单样本行人重识别方法更有效和先进。
[0098]
上述实施方式仅为例举,不表示对本发明范围的限定。这些实施方式还能以其它各种方式来实施,且能在不脱离本发明技术思想的范围内作各种省略、置换、变更。
转载请注明原文地址:https://doc.8miu.com/read-1719393.html

最新回复(0)