1.本公开涉及深度学习技术领域,尤其涉及模型训练领域,具体涉及一种模型训练的方法、对象分类方法、装置及电子设备。
背景技术:
2.所谓多任务模型,是指能够针对对象同时实现多个分类任务的模型。例如,在图像分类时,通过多任务模型,可以针对一张图像同时实现多种图像分类。
3.相关技术中,主要通过硬参数共享的方式训练得到多任务模型。然而,采用硬参数共享的方式需要训练的多任务模型的网络参数较多。
技术实现要素:
4.本公开提供了一种用于减少多任务模型内需要训练的网络参数的模型训练的方法、对象分类方法、装置及电子设备。
5.根据本公开的一方面,提供了一种模型训练的方法,包括:
6.获取用于模型训练的样本集;其中,所述样本集包含多任务模型的每一分类任务的样本对象,所述多任务模型包含第一类网络和第二类网络,所述第一类网络包含全连接层,以及末端的特征提取层中的归一化层;所述第二类网络为除所述第一类网络以外的网络结构;
7.利用所述样本集所包括的各样本对象,以及各样本对象对应的任务标识,对所述多任务模型进行训练;
8.其中,每一样本对象对应的任务标识为该样本对象所属分类任务的标识;在训练过程中,每一样本对象用于训练所述第二类网络的网络参数,以及所述第一类网络的针对相应任务的网络参数,所述相应任务为具有该样本对象对应的任务标识的分类任务。
9.根据本公开的另一方面,提供了一种对象分类方法,包括:
10.获取待分类的目标对象;
11.基于预先训练的目标多任务模型,对所述目标对象进行多任务分类,得到各个分类任务的分类结果;
12.其中,所述目标多任务分类模型为利用任一项所述的模型训练的方法所训练得到的模型。
13.根据本公开的另一方面,提供了一种模型训练的装置,包括:
14.样本集获取模块,用于获取用于模型训练的样本集;其中,所述样本集包含多任务模型的每一分类任务的样本对象,所述多任务模型包含第一类网络和第二类网络,所述第一类网络包含全连接层,以及末端的特征提取层中的归一化层;所述第二类网络为除所述第一类网络以外的网络结构;
15.模型训练模块,用于利用所述样本集所包括的各样本对象,以及各样本对象对应的任务标识,对所述多任务模型进行训练;
16.其中,每一样本对象对应的任务标识为该样本对象所属分类任务的标识;在训练过程中,每一样本对象用于训练所述第二类网络的网络参数,以及所述第一类网络的针对相应任务的网络参数,所述相应任务为具有该样本对象对应的任务标识的分类任务。
17.根据本公开的另一方面,提供了一种对象分类装置,包括:
18.对象获取模块,用于获取待分类的目标对象;
19.对象分类模块,用于基于预先训练的目标多任务模型,对所述目标对象进行多任务分类,得到各个分类任务的分类结果;
20.其中,所述目标多任务分类模型为利用任一项所述的模型训练的装置所训练得到的模型。
21.根据本公开的另一方面,提供了一种电子设备,包括:
22.至少一个处理器;以及
23.与所述至少一个处理器通信连接的存储器;其中,
24.所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行模型训练的方法或对象分类方法。
25.根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行模型训练的方法或对象分类方法。
26.根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现模型训练的方法或对象分类方法。
27.应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
28.附图用于更好地理解本方案,不构成对本公开的限定。其中:
29.图1是相关技术中采用硬参数共享的方式所训练的多任务模型的示意图;
30.图2是相关技术中resnet50网络的示意图;
31.图3是根据本公开第一实施例的示意图;
32.图4是根据本公开第二实施例的示意图;
33.图5是根据本公开第三实施例的示意图;
34.图6是根据本公开实施例提供的多任务模型的示意图;
35.图7是根据本公开第四实施例的示意图;
36.图8是根据本公开第五实施例的示意图;
37.图9是用来实现本公开实施例的模型训练的方法或对象分类方法的电子设备的框图。
具体实施方式
38.以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
39.随着深度学习地不断发展,目前深度学习中对象分类也是成为很有重要的算法任务。
40.在实际需求中,往往需要执行多个并行无关联的分类任务。如果一个分类任务均需要对应一个对象分类模型,那么在训练和部署这些对象分类模型时,工作量将会大大增加且部署效率也较低。
41.而多任务模型成为解决这些问题的有效手段,其中,多任务模型是指能够针对对象同时实现多个分类任务的模型。例如,在图像分类时,通过多任务模型,可以针对一张图像同时实现多种图像分类。
42.相关技术中,主要通过硬参数共享的方式训练得到多任务模型。简单而言,多任务模型中,各分类任务共享多任务模型中底层的特征提取层,同时每一分类任务具有独立的末端的特征提取层,以及独立的全连接层。如图1所示,为相关技术中采用硬参数共享的方式所训练的多任务模型的示意图。图1中task(任务)a、task b、task c三个任务共享底层的特征提取层,即图1中的shared layes(共享层),同时每个任务在高层上有各自的末端的特征提取层和全连接层,即图中的specific layes(独有层),构成一种底层共享,高层独立的结构。
43.示例性的,如图2所示,为相关技术中resnet50网络的示意图。其中:input(输入)是输入层,用于输入待处理的对象;stem layer(茎层)是初始层,包含一个卷积层和池化层;layer1、layer2、layer3和layer4是resnet50主要结构,其中,layer4层表示resnet50的倒数3个卷积层结构组成。每个layer层包含若干个bottleneck(瓶颈)结构,bottleneck结构入如图2中的右半部分所示,每个bottleneck结构由三个conv bn relu层构成,其中,conv表示卷积层,bn表示批量归一化标准化,即归一化层,relu是一种激活函数,即激活函数层,表达式为f(x)=max(0,x);gap(global average pooling,全局平均池化)表示全局池化层,输出固定2048维的向量;fc为全连接层。
44.如果使用经典的resnet50作为多任务模型的基础网络,那么采用硬参数共享的方式为:任务a、任务b和任务c共用resnet50网络中底层的input、stem layer、layer1层、layer2层和layer3层,而最后的layer4层、gap层和全连接层是每个任务独自拥有。
45.然而,由于对于神经网络模型而言,末端的特征提取层往往包含较多的网络参数,而对于多任务分类模型而言,每个分类任务均具有独立的末端的特征提取层,使得采用硬参数共享的方式需要训练的网络参数较多。进而导致训练时所占用的内存较多。同时,在对训练完成后的模型进行部署时,需要占用较多的内存。
46.相关技术中,还存在一种各分类任务共享所有的特征提取层,且每种分类任务仅具有独立的全连接层的方案。然而,因为所有的特征提取层的网络参数均是共享的,使得在某一个分类任务达到最佳分类时,其他分类任务无法达到最佳分类。即出现无法同时收敛的问题。
47.为了减少解决相关技术中上述所存在的技术问题,本公开实施例提供一种模型训练的方法。
48.需要说明的,在具体应用中,本公开实施例所提供的模型训练的方法可以应用于各类电子设备,例如,个人电脑、服务器、以及其他具有数据处理能力的设备。另外,可以理解的是,本公开实施例提供的模型训练方法可以通过软件、硬件或软硬件结合的方式实现。
49.其中,本公开实施例所提供的一种模型训练的方法可以包括:
50.获取用于模型训练的样本集;其中,样本集包含多任务模型的每一分类任务的样本对象,多任务模型包含第一类网络和第二类网络,第一类网络包含全连接层,以及末端的特征提取层中的归一化层;第二类网络为除第一类网络以外的网络结构;
51.利用样本集所包括的各样本对象,以及各样本对象对应的任务标识,对多任务模型进行训练;
52.其中,每一样本对象对应的任务标识为该样本对象所属分类任务的标识;在训练过程中,每一样本对象用于训练第二类网络的网络参数,以及第一类网络的针对相应任务的网络参数,相应任务为具有该样本对象对应的任务标识的分类任务。
53.本公开提供的上述方案,由于在训练过程中,每一样本对象用于训练第二类网络的网络参数,以及第一类网络的针对具有该样本对象对应的任务标识的分类任务的网络参数,使得训练后的多任务模型中,第二类网络的网络参数是各分类任务所共享的,第一类网络中各网络参数是分属于各分类任务的。又因为第一类网络仅包含末端的特征提取层中的归一化层,故而末端的特征提取层中除归一化层外的网络参数是各分类任务所共享的,这样,减少了每一分类任务独有的网络参数的数量。可见,通过本公开提供的上述方案可以减少多任务模型内需要训练的网络参数。进而可以减少训练与部署多任务模型时占用的内存。
54.进一步的,由于各分类任务在末端的特征提取层中具有独有的归一化层内的参数,从而可以避免各分类任务之间出现训练冲突的问题。
55.下面结合附图对本公开实施例所提供的一种模型训练的方法进行介绍。
56.如图3所示,本公开实施例提供所一种模型训练的方法,可以包括如下步骤:
57.s301,获取用于模型训练的样本集;其中,样本集包含多任务模型的每一分类任务的样本对象,多任务模型包含第一类网络和第二类网络,第一类网络包含全连接层,以及末端的特征提取层中的归一化层;第二类网络为除第一类网络以外的网络结构;
58.其中,多任务模型中第一类网络内的网络参数分属于不同分类任务。举例而言,多任务模型的分类任务包含分类任务1、分类任务2和分类任务3,则多任务模型的第一类网络中包含针对分类任务1的网络参数1、针对分类任务2的网络参数2以及针对分类任务3的网络参数3。而多任务模型的第二类网络中的网络参数是各分类任务共享的,即仅包含一份属于网络参数,该网络参数在针对各分类任务时均可以使用。
59.第一类网络包含全连接层和末端的特征提取层中的归一化层。而对于末端的特征提取层中的卷积层、激活函数层等网络层则是属于第二类网络。上述特征提取层是不同网络结构中所划分的结构层。以resnet50网络为例,resnet50网络中的特征提取层包括:layer1、layer2、layer3和layer4。layer1、layer2、layer3和layer4各自包含有多个卷积层、激活函数层和归一化层。在本公开实施例中,上述layer4即为末端的特征提取层,也可以称为最高层的特征提取层。
60.所获取的样本集可以是预先建好的。为了对多任务模型进行训练,所获取的样本集中需要包含多任务模型的每一分类任务的样本对象。举例而言,多任务模型的分类任务包含分类任务1和分类任务2,而样本集中包含有针对分类任务1的样本对象1和样本对象2,以及针对分类任务2的样本对象3和样本对象4,其中,样本对象1和样本对象2用于训练多分
类模型以实现分类任务1,样本对象3和4用于训练多分类任务以实现分类任务2。
61.需要说明的,多任务模型可以为对图像、音频等对象进行分类的模型。与多任务模型对应的,样本对象可以为图像、音频等对象。示例性的,若多任务模型为多任务图像分类模型,如基于cnn(convolutional neural network,卷积神经网络)所构建的模型,则样本对象可以为样本图像。
62.多任务模型的各分类任务为针对同一对象的分类任务,如针对图像进行分类的多个分类任务。各分类任务可以为并行无关联的任务,例如,分类任务1为对图像的颜色进行分类,分类任务2为对图像包含的对象进行分类;各分类任务也可以为粗细分类的任务,例如分类任务1为对图像包含的对象进行分类,分类任务2为对图像包含的建筑进行分类等。
63.s302,利用样本集所包括的各样本对象,以及各样本对象对应的任务标识,对多任务模型进行训练;其中,每一样本对象对应的任务标识为该样本对象所属分类任务的标识;在训练过程中,每一样本对象用于训练第二类网络的网络参数,以及第一类网络的针对相应任务的网络参数,相应任务为具有该样本对象对应的任务标识的分类任务。
64.其中,样本集中包含的样本对象对应有任务标识。每一任务标识表征该样本对象所属的分类任务。可选的,与每一样本对象与任务标识的对应关系可以基于样本对象的命名实现。举例而言,分类任务1的任务标识为1,则属于分类任务1的样本对象可以以为1
‑
xx进行命名,例如1
‑
01、1
‑
02等。可选的,上述任务标识可以为任务的任务id(identity document,身份证标识号)。
65.在获取到样本集后,即可利用样本集所包括的各样本对象,以及各样本对象对应的任务标识,对多任务模型进行训练。
66.在对多任务模型进行训练过程中,每一样本对象在被输入至多任务模型后,该样本对象仅用于训练第二类网络的网络参数,以及第一类网络的针对相应任务的网络参数,不用于训练第一类网络的非相应任务的网络参数。
67.举例而言,多任务模型的分类任务包含分类任务1和分类任务2,而样本集中包含有分类任务1的样本对象1和样本对象2,以及分类任务2的样本对象3和样本对象4。当将样本对象1输入至多任务分类模型时,样本对象1仅用于训练第二类网络的网络参数,以及第一类网络的针对分类任务1的网络参数,而不用于训练第一类网络的针对分类任务2的网络参数。
68.本公开提供的上述方案,由于因为第一类网络仅包含末端的特征提取层中的归一化层,故而第二类网络特征提取层中除归一化层外的网络参数是各分类任务所共享的,减少了每一分类任务独有的网络参数的数量。可见,通过本公开提供的上述方案可以减少多任务模型内需要训练的网络参数。同时,由于各分类任务在末端的特征提取层中具有独有的归一化层内的参数,从而可以避免各分类任务之间出现训练冲突的问题。
69.基于图3的实施例,如图4所示,本公开的另一实施例所提供的模型训练的方法,上述s302,可以包括步骤s3021
‑
s3024:
70.s3021,从样本集中,选取目标样本对象;
71.其中,从样本集中选取目标样本对象的方式可以为多种,例如,通过随机的方式从样本集中选取目标样本对象,或者预设的非随机的选取方式从样本集中选取目标样本对象,具体将在后续实施例进行介绍,在此不再赘述。
72.可选的,在一种实现方式中,为了充分利用到样本集中每一样本对象对目标分类模型进行训练,可以记录每次所获取的样本对象,从而可以在获取样本对象时,从样本集中获取未被利用过的样本对象,作为目标样本对象。
73.s3022,将选取的目标样本对象和对应的任务标识,输入多任务模型,以使多任务模型基于第二类网络的网络参数以及第一类网络的针对指定任务的网络参数,对目标样本对象进行分类,得到分类结果;其中,指定任务为具有目标样本对象对应的任务标识的分类任务;
74.其中,通过目标样本对象对应的任务标识可以确定目标样本对象所属的分类任务。因此,可以将目标样本对象和对应的任务标识,输入多任务模型。
75.而多任务模型在接收到目标样本对象及其对应的任务标识后,可以确定第一类网络中用于对目标样本进行处理的网络参数,即第一类网络的针对指定任务的网络参数,进而可以基于第二类网络的网络参数以及第一类网络的针对指定任务的网络参数,对目标样本对象进行分类,得到分类结果。
76.s3023,基于所得到的分类结果,与目标样本对象的标定结果的差异,调整第二类网络的网络参数以及第一类网络的指定任务的网络参数;
77.其中,在得到目标样本对象的分类结果之后,可以计算该分类结果与目标样本对象的标定结果的差异,该差异也称为模型损失。
78.举例而言,目标样本对象是针对分类任务1的样本对象,分类任务1用于将对象分为类别1、类别2和类别3。目标样本对象的标定结果为类别3。将目标样本对象输入至多任务模型,多任务模型在基于第二类网络的网络参数以及第一类网络的针对分类任务1的网络参数,对目标样本对象进行分类,得到分类结果为:类别1的概率20%、类别2的概率10%、类别3的概率70%,则可计算标定结果与分类结果的差异为:30%,或0.3。
79.在确定目标样本对象的标定结果的差异之后,可以基于所得到的分类结果,与目标样本对象的标定结果的差异,调整第二类网络的网络参数以及第一类网络的指定任务的网络参数。
80.其中,对于神经网络模型而言,差异越大,所要调整的参数的调整幅度也越大,因此,可以结合实际情况与需求,基于与目标样本对象的标定结果的差异,调整多任务模型中第二类网络的网络参数以及第一类网络的指定任务的网络参数。
81.可选的,在一种实现方式中,可以采用预定参数调整方式,对多任务模型中第二类网络的网络参数以及第一类网络的指定任务的网络参数进行调整。示例性的,预定参数调整方式可以随机梯度下降方式、批量梯度下降方式等等。
82.s3024,判断样本集中的样本对象是否均已被选取;若否,则返回执行步骤s3021;否则,训练结束。
83.可选的,在一种实现方式中,在对多任务模型中参数进行调整之后,还需要判断样本集中的样本对象是否均已被选取,若样本集中还存在未被选取过的样本对象,则还可以返回执行从样本集中,选取目标样本对象的步骤,即返回步骤s3021,直至样本集中所有的样本对象均已被选取。若样本集中所有的样本对象均已被选取,则训练结束。
84.本公开提供的上述方案,可以减少多任务模型内需要训练的网络参数,且可以避免各分类任务之间出现训练冲突的问题。进一步的,利用目标样本对象对多任务模型中第
二类网络的网络参数以及第一类网络的指定任务的网络参数进行调整,可以使训练后的多任务模型可以执行该指定任务。
85.可选的,针对多任务模型训练中需要采用不同样本对象进行调整的网络参数的情况,可以针对每一分类任务,先基于该分类任务的多个样本对象,对多任务模型中第二类网络的网络参数以及第一类网络的该分类任务的网络参数进行调整之后,再针对另一个分类任务的网络参数进行调整,避免多任务模型频繁的切换需要进行训练的网络参数。
86.基于此,本公开的另一实施例所提供的模型训练的方法,上述s3021,可以包括:
87.从样本集中,选取目标分类任务的多个目标样本对象;目标分类任务为多个分类任务中任一任务。
88.其中,针对任一分类任务,在将其确定为目标分类任务之后,可以先从样本集中选取目标分类任务的多个目标样本对象,进而将该多个目标样本对象依次输入多任务模型中,以对多任务模型中第二类网络的网络参数以及第一类网络的该目标分类任务的网络参数进行调整。
89.此时,在步骤s3023返回从样本集中,选取目标样本对象之前,还包括:
90.从多个分类任务中,确定新的目标分类任务。
91.其中,新的目标分类任务可以为随机确定的。可选的,在一种实现方式中,上述从多个分类任务中,确定新的目标分类任务,可以包括:
92.按照轮流选取任务的方式,从多个分类任务中,确定新的目标分类任务。
93.举例而言,多任务模型的分类任务包含:分类任务1、分类任务2和分类任务3,则先将分类任务1确定为目标分类任务,再将分类任务2确定为目标分类任务,最后将分类任务3确定为目标分类任务,一轮结束后,再次将分类任务1确定为目标分类任务,依次类推,直至完成训练。
94.通过轮流选取任务的方式,可以避免多任务模型中第二类网络中网络参数偏向于某一分类任务,从而可以提高多任务模型分类的准确性。
95.本公开提供的上述方案,可以减少多任务模型内需要训练的网络参数,且可以避免各分类任务之间出现训练冲突的问题。进一步的,还可以避免多任务模型频繁的切换需要进行训练的网络参数。
96.可选的,在一实施例中,第一类网络中还可以包括:设置在全连接层以及末端的特征提取层之间的全局池化层。末端的特征提取层将输出的特征向量输入至全局池化层,由全局池化层处理后再输入至全连接层。
97.可选的,在一实施例中,第一类网络中还可以包括:设置在全连接层以及末端的特征提取层之间的第一池化层和特征融合层;第二类网络中包括与末端的特征提取层相连接的第二池化层;特征融合层用于将第一池化层和第二池化层的特征进行融合后输入全连接层。
98.其中,第二池化层可以将底层的特征提取网络所提取的特征信息输入至特征融合层,而特征融合层可以将第一池化层和第二池化层的特征进行融合后输入全连接层。从而使得全连接层可以同时基于底层的特征提取网络所提取的特征信息,以及高层的特征提取网络所提取的特征信息进行对象分类。可以提升多任务模型的分类效果。
99.如图5所示,本公开实施例提供所一种对象分类方法,可以包括如下步骤:
100.s501,获取待分类的目标对象;
101.其中,在基于本发明所提供的模型训练的方法所训练得到的多任务模型,并部署完成之后,多任务模型可以实现针对对象的分类。此时,可以获取待分类的目标对象。
102.s502,基于预先训练的目标多任务模型,对目标对象进行多任务分类,得到各个分类任务的分类结果。
103.其中,当获取到目标对象后,可以基于预先训练的目标多任务模型,对目标对象进行多任务分类,得到各个分类任务的分类结果。
104.举例而言,目标多任务模型1的各分类任务包含分类任务1、分类任务2和分类任务3。则可以利用目标多任务模型1对目标对象进行多任务分类,得到分类任务1的分类结果、分类任务2的分类结果以及分类任务3的分类结果。
105.可选的,在一种实现方式中,可以将目标对象,输入目标多任务模型,以使目标多任务模型基于第二类网络的网络参数以及第一类网络的针对每一分类任务的网络参数,对目标对象进行多任务分类,得到各个分类任务的分类结果。
106.其中,目标多任务模型在获取到目标对象后,可以基于第二类网络的网络参数对目标对象进行对象处理,得到中间处理结果。再依次为中间处理结果设置各个分类任务的任务标识,并在每次设置任务标识后,利用第一类网络的、针对当前所设置的任务标识所对应分类任务的网络参数,对中间处理结果进行对象处理,得到当前所设置的任务标识所对应分类任务的分类结果。
107.上述对象处理可以为特征提取、特征池化、特征分类等处理。
108.上述基于第二类网络的网络参数对目标对象进行对象处理可以为利用底层特征提取层对目标对象进行特征提取,而上述中间分类结果可以为多任务模型中底层特征提取层所提取的底层特征。
109.在提取出底层特征后,既可为该底层特征设置各个分类任务中任意分类任务的任务标识。进而利用第一类网络的、针对当前所设置的任务标识所对应分类任务的网络参数,对底层特征进行特征提取、特征池化、特征分类等处理,得到当前所设置的任务标识所对应分类任务的分类结果。
110.在得到该分类结果之后,继续为所提取的底层特征设置下一分类任务的任务标识,重复上述过程,直至得到所有分类任务的分类结果。
111.本公开提供的上述方案,由于获取目标对象,可以对目标对象进行多任务分类,得到各个分类任务的分类结果,从而使得对目标对象进行分类的速度较快。
112.为了更好的理解本公开所提供的方案,如图6所示,以resnet50网络所构建用于进行图像分类的多任务模型的场景为例,介绍下本公开所提供的方案。该多任务模型中的第二类网络包括resnet50的stem layer,layer1,layer2、layer3,以及layer4中的conv和relu,第一类网络包括:layer4中bn、gap和fc层。
113.假定现在有4个图像分类任务,则4个图像分类任务共用resnet50的stem layer、layer1、layer2和layer3底层的特征提取层,以及在layer4中除了bn外的conv和relu。而在layer4中,每个图像分类任务拥有各自的网络参数(即bn层中的bn参数),并且每个图像分类任务各自的gap层内网络参数以及fc层内网络参数。使得各图像分类任务之间可以同时收敛,同时bn的参数量较少。使得训练和部署的过程中基本不占内存,同时给推理部署增加
很多的便利,并且还也可以支持拓展更多的图像分类任务。当需要增加图像分类任务时,只需要在bn层中增加一些与该图像分类任务对应的bn参数即可,具有很好的拓展性。并且,本方案同时也适用于其他带有bn结构的其他分类任务,具有很好的拓展性。
114.当需要对基于resnet50网络的多任务模型进行训练时,样本集中的样本的对象除具有图像路径和标定结果(如标签)外,还需要增加与样本对象对应的任务标识(如任务id)。这样当样本对象通过底层的特征提取层提取出底层特征后,在经过layer4时,可以根据样本对象对应的对象标识,来决定需要利用的bn层中的属于任务标识所指示图像分类任务的bn参数。比如图像分类任务0进行训练的时候,只使用layer4的bn层中bn1_0、bn2_0、bn3_0等以_0为后缀的bn参数,而layer4的bn层中属于其他图像分类任务的bn参数不使用。
115.当对目标图像进行分类的过程中,可以将目标图像输入至基于resnet50网络的多任务模型。经过stem layer、layer1、layer2、layer3输出14x14x1024维向量n。在经过layer4时,需要一次按照图像分类任务设置图像分类任务的任务标识,进而采用layer4的bn层中与任务标识对应的bn参数进行处理。在layer4处理完成之后,将layer4的处理结果经过对应的gap和fc层输出因为有4个图像分类任务,因此layer4需要依次设置4遍图像分类任务的任务标识,执行4遍上述过程,从而得到4个图像分类任务的分类结果。
116.进一步的,如上图6所示,还可以将layer2的输出特征经过gap得到512维特征,跟layer4的经过gap的2048维特征,拼接得到2560维特征。从而可以利用上layer2的底层特征,进而可以提升多任务的分类效果。
117.根据本公开的实施例,如图7所示,本公开还提供了一种模型训练的装置,上述装置包括:
118.样本集获取模块701,用于获取用于模型训练的样本集;其中,样本集包含多任务模型的每一分类任务的样本对象,多任务模型包含第一类网络和第二类网络,第一类网络包含全连接层,以及末端的特征提取层中的归一化层;第二类网络为除第一类网络以外的网络结构;
119.模型训练模块702,用于利用样本集所包括的各样本对象,以及各样本对象对应的任务标识,对多任务模型进行训练;
120.其中,每一样本对象对应的任务标识为该样本对象所属分类任务的标识;在训练过程中,每一样本对象用于训练第二类网络的网络参数,以及第一类网络的针对相应任务的网络参数,相应任务为具有该样本对象对应的任务标识的分类任务。
121.可选的,模型训练模块,包括:
122.对象选取子模块,用于从样本集中,选取目标样本对象;
123.对象分类子模块,用于将选取的目标样本对象和对应的任务标识,输入多任务模型,以使多任务模型基于第二类网络的网络参数以及第一类网络的针对指定任务的网络参数,对目标样本对象进行分类,得到分类结果;其中,指定任务为具有目标样本对象对应的任务标识的分类任务;
124.网络参数调整子模块,用于基于所得到的分类结果,与目标样本对象的标定结果的差异,调整第二类网络的网络参数以及第一类网络的指定任务的网络参数;并返回执行对象选取子模块,直至样本集中的样本对象均已被选取。
125.可选的,对象选取子模块,还用于从样本集中,选取目标分类任务的多个目标样本
对象;目标分类任务为多个分类任务中任一任务;
126.网络参数调整子模块,还包括:
127.任务确定子模块,用于在返回执行对象选取子模块之前,从多个分类任务中,确定新的目标分类任务。
128.可选的,任务确定子模块,还用于按照轮流选取任务的方式,从多个分类任务中,确定新的目标分类任务。
129.可选的,第一类网络中还包括:设置在全连接层以及末端的特征提取层之间的第一池化层和特征融合层;第二类网络中包括与末端的特征提取层相连接的第二池化层;特征融合层用于将第一池化层和第二池化层的特征进行融合后输入全连接层。
130.本公开提供的上述方案,由于在训练过程中,每一样本对象用于训练第二类网络的网络参数,以及第一类网络的针对具有该样本对象对应的任务标识的分类任务的网络参数,使得训练后的多任务模型中,第二类网络的网络参数是各分类任务所共享的,第一类网络中各网络参数是分属于各分类任务的。又因为第一类网络仅包含末端的特征提取层中的归一化层,故而第二类网络特征提取层中除归一化层外的网络参数是各分类任务所共享的,减少了每一分类任务独有的网络参数的数量。可见,通过本公开提供的上述方案可以减少多任务模型内需要训练的网络参数。
131.进一步的,由于各分类任务在末端的特征提取层中具有独有的归一化层内的参数,从而可以避免各分类任务之间出现训练冲突的问题。
132.根据本公开的对象分类方法的实施例,如图8所示,本公开还提供了一种对象分类装置,上述装置包括:
133.对象获取模块801,用于获取待分类的目标对象;
134.对象分类模块802,用于基于预先训练的目标多任务模型,对目标对象进行多任务分类,得到各个分类任务的分类结果;
135.其中,目标多任务分类模型为利用本发明所提供的模型训练的装置所训练得到的模型。
136.可选的,对象分类模块,包括:
137.对象输入子模块,用于将目标对象,输入目标多任务模型,以使目标多任务模型基于第二类网络的网络参数以及第一类网络的针对每一分类任务的网络参数,对目标对象进行多任务分类,得到各个分类任务的分类结果。
138.可选的,目标多任务模型,还用于基于第二类网络的网络参数对目标对象进行对象处理,得到中间处理结果;依次为中间处理结果设置各个分类任务的任务标识,并在每次设置任务标识后,利用第一类网络的、针对当前所设置的任务标识所对应分类任务的网络参数,对中间处理结果进行对象处理,得到当前所设置的任务标识所对应分类任务的分类结果。
139.本公开提供的上述方案,由于获取目标对象,可以对目标对象进行多任务分类,得到各个分类任务的分类结果,从而使得对目标对象进行分类的速度较快。
140.根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
141.本公开实施例提供了一种电子设备,包括:
142.至少一个处理器;以及
143.与至少一个处理器通信连接的存储器;其中,
144.存储器存储有可被至少一个处理器执行的指令,指令被至少一个处理器执行,以使至少一个处理器能够执行模型训练的方法或对象分类方法。
145.本公开实施例提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,计算机指令用于使计算机执行模型训练的方法或对象分类方法。
146.本公开实施例提供了一种计算机程序产品,包括计算机程序,计算机程序在被处理器执行时实现模型训练的方法或对象分类方法。
147.图9示出了可以用来实施本公开的实施例的示例电子设备900的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
148.如图9所示,设备900包括计算单元901,其可以根据存储在只读存储器(rom)902中的计算机程序或者从存储单元908加载到随机访问存储器(ram)903中的计算机程序,来执行各种适当的动作和处理。在ram 903中,还可存储设备900操作所需的各种程序和数据。计算单元901、rom 902以及ram 903通过总线904彼此相连。输入/输出(i/o)接口905也连接至总线904。
149.设备900中的多个部件连接至i/o接口905,包括:输入单元906,例如键盘、鼠标等;输出单元907,例如各种类型的显示器、扬声器等;存储单元908,例如磁盘、光盘等;以及通信单元909,例如网卡、调制解调器、无线通信收发机等。通信单元909允许设备900通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
150.计算单元901可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元901的一些示例包括但不限于中央处理单元(cpu)、图形处理单元(gpu)、各种专用的人工智能(ai)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(dsp)、以及任何适当的处理器、控制器、微控制器等。计算单元901执行上文所描述的各个方法和处理,例如模型训练的方法或对象分类方法。例如,在一些实施例中,模型训练的方法或对象分类方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元908。在一些实施例中,计算机程序的部分或者全部可以经由rom 902和/或通信单元909而被载入和/或安装到设备900上。当计算机程序加载到ram 903并由计算单元901执行时,可以执行上文描述的模型训练的方法或对象分类方法的一个或多个步骤。备选地,在其他实施例中,计算单元901可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行模型训练的方法或对象分类方法。
151.本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(fpga)、专用集成电路(asic)、专用标准产品(assp)、芯片上系统的系统(soc)、负载可编程逻辑设备(cpld)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器
可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
152.用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
153.在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(ram)、只读存储器(rom)、可擦除可编程只读存储器(eprom或快闪存储器)、光纤、便捷式紧凑盘只读存储器(cd
‑
rom)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
154.为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,crt(阴极射线管)或者lcd(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
155.可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(lan)、广域网(wan)和互联网。
156.计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端
‑
服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式系统的服务器,或者是结合了区块链的服务器。
157.应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
158.上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
转载请注明原文地址:https://doc.8miu.com/read-1719092.html