1.本发明涉及人工智能和图像处理技术领域,尤其涉及一种基于中间层特征提取增强的知识蒸馏实现图像分类的方法。
背景技术:
2.随着人工智能领域中深度学习的发展,日益复杂的卷积神经网络模型在带来高性能的同时,其巨大的参数量和通道量却需要消耗大量的存储资源和计算资源,对模型在移动端和嵌入式设备中应用带来困难。因此,为了降低神经网络模型的存储占用空间和计算开销,典型的神经网络优化方法包括参数量化、紧凑模型、剪枝和共享、低秩分解和知识蒸馏等。
3.知识蒸馏是对深度模型进行优化的一种常用的方法。通过定义合适的蒸馏损失,知识蒸馏往往能够超越传统的基于真实标签的监督训练方法,实现模型推理准确率的提升。知识蒸馏的训练框架中通常包含一个或多个教师模型,以及一个学生模型。当学生模型的大小小于教师模型时,其本质上就实现了一个高效的模型压缩过程。对比其他神经网络优化方法,知识蒸馏的优点在于无需对模型结构进行复杂的修改,实施过程相对简单,可推广性好,模型稳定性好。通过知识蒸馏获得的轻量化模型具有确定的压缩比和准确率下限,确保了在特定的场景中或设备上的可部署性。
4.然而,最早的知识蒸馏方法只考虑对教师模型的分对数输出,而忽视了模型中间层所包含的丰富的暗知识。因此,近年来出现了一些基于中间层特征的知识蒸馏方法。
5.现有方法中,基于中间层的知识蒸馏在图像分类中主要存在以下问题:(1)缺乏对多尺度像素间关系的表征;(2)存在背景噪声干扰:对于图像分类任务而言,一张中间层特征图像中只有少量的关键特征决定了图像的分类结果,剩余的大部分像素对于得出分类结果毫无贡献,因此对所有像素不加筛选地进行蒸馏,不仅降低了蒸馏学习的效率,甚至由于在蒸馏过程中拟合了大量背景噪声而有害于学生模型的训练。
技术实现要素:
6.本发明实施例所要解决的技术问题在于,提供一种基于中间层特征提取增强的知识蒸馏实现图像分类的方法,通过增强知识蒸馏对中间层特征提取能力来改善卷积神经网络优化效果,用以增强图像分类效果,从而解决了现有技术中所存在的缺乏对多尺度像素间关系的表征以及背景噪声干扰的问题。
7.为了解决上述技术问题,本发明实施例提供了一种基于中间层特征提取增强的知识蒸馏实现图像分类的方法,所述方法包括以下步骤:
8.获取待分类图像;
9.将所述待分类图像导入预先训练好的教师
‑
学生网络中,得到相应的分类结果;其中,所述预先训练好的教师
‑
学生网络是基于历史图像分别输入教师模型和学生模型中,并采用预设的跨层非局部模块分别提取学生模型和教师模型的多尺度像素间关系,且待计算
出教师模型和学生模型间的多尺度像素关系蒸馏损失之后,将蒸馏损失加入学生模型的损失函数中,进一步根据损失函数反向传播更新学生模型参数直至学生模型收敛,将收敛后的学生模型作为优化模型输出进行训练得到的。
10.其中,所述跨层非局部模块采用如下公式进行计算:
11.r=(x
q
,x
r1
,
…
,x
rn
)=x
q
∑z
ri
12.其中,x
q
为查询层特征;x
ri
为响应层特征;z
ri
为响应层i与查询层的像素间关系,表示为z为卷积运算;θ(
·
),φ(
·
)和g(
·
)均为可学习嵌入式函数,使用1
×
1卷积实现;θ(x
q
),gi(x
ri
)为可学习嵌入函数对输入的特征图做预处理,计算单个像素的表示;f(
·
,
·
)为二维函数,使用点积实现;为计算对应位置像素间的相关程度。
13.其中,所述跨层非局部模块提取学生模型或教师模型的多尺度像素间关系的具体步骤如下:
14.将历史图像作为学生模型或教师模型的输入,并输入相应模型的第一层;
15.若第一层是选定的响应层,将第一层的输出特征作为响应层输入其后的跨层非局部模块,并将跨层非局部模块的输出特征输入其后的第二层;或若第一层是选定的查询层,将第一层的输出特征作为查询层输入其后的跨层非局部模块;
16.用第二层更新第一层;
17.若第一层是最后一层,将最后一层的输出特征作为预测结果并输出。
18.其中,计算教师模型和学生模型间的多尺度像素关系蒸馏损失时,采用l2范式损失的形式如下:
19.l
蒸馏
=l2(r
t
,m(r
s
))
20.其中,m(r
s
)是可学习的匹配函数,使教师模型和学生模型的多尺度关系特征图在维度和尺寸上匹配;r
s
首先通过一个卷积层c(
·
),然后再通过一个上采样函数h(
·
)进行匹配,即m(r
s
)=h(c(r
s
))。
21.其中,将蒸馏损失加入到学生模型的损失函数中时,采用如下公式进行运算:
22.l
总
=l
分类
αl
蒸馏
23.其中,l
总
为总损失函数;l
分类
为分类损失函数;l
蒸馏
为蒸馏损失函数;α为蒸馏损失函数在总损失函数中占的比例系数。
24.其中,所述分类损失函数采用交叉熵形式计算,具体公式如下:
[0025][0026]
其中,y为图像的真实分类标签,为学生模型输出的预测结果。
[0027]
实施本发明实施例,具有如下有益效果:
[0028]
本发明针对现有基于中间层的知识蒸馏方法中缺乏对多尺度像素间关系表征的问题,在教师
‑
学生网络的中间层插入一个可学习的跨层非局部模块,提取教师模型和学生模型的多尺度像素间关系,通过知识蒸馏的方式使学生模型拟合教师模型的多尺度像素间关系,改善学生模型的中间层输出,提升学生模型的特征提取能力,有效地提升知识蒸馏对教师
‑
学生网络优化效果,增强了图像分类效果,从而解决了现有技术中所存在的缺乏对多
尺度像素间关系的表征以及背景噪声干扰的问题。
附图说明
[0029]
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,根据这些附图获得其他的附图仍属于本发明的范畴。
[0030]
图1为本发明实施例提供的基于中间层特征提取增强的知识蒸馏实现图像分类的方法的流程图;
[0031]
图2为本发明实施例提供的基于中间层特征提取增强的知识蒸馏实现图像分类的方法中训练教师
‑
学生网络的流程图;
[0032]
图3为本发明实施例提供的基于中间层特征提取增强的知识蒸馏实现图像分类的方法中跨层非局部模块的结构示意图;
[0033]
图4为本发明实施例提供的基于中间层特征提取增强的知识蒸馏实现图像分类的方法中使用跨层非局部模块进行多尺度像素关系提取的流程图。
具体实施方式
[0034]
为使本发明的目的、技术方案和优点更加清楚,下面将结合附图对本发明作进一步地详细描述。
[0035]
如图1所示,为本发明实施例中,提出的一种基于中间层特征提取增强的知识蒸馏实现图像分类的方法,所述方法包括以下步骤:
[0036]
步骤s1、获取待分类图像;
[0037]
步骤s2、将所述待分类图像导入预先训练好的教师
‑
学生网络中,得到相应的分类结果;其中,所述预先训练好的教师
‑
学生网络是基于历史图像分别输入教师模型和学生模型中,并采用预设的跨层非局部模块分别提取学生模型和教师模型的多尺度像素间关系,且待计算出教师模型和学生模型间的多尺度像素关系蒸馏损失之后,将蒸馏损失加入学生模型的损失函数中,进一步根据损失函数反向传播更新学生模型参数直至学生模型收敛,将收敛后的学生模型作为优化模型输出进行训练得到的。
[0038]
具体过程为,在步骤s1之前,预先训练教师
‑
学生网络,其训练过程如图2所示,具体包括:
[0039]
s201:接收历史图像,作为卷积神经网络模型的输入分别输入教师模型和学生模型;
[0040]
s202:教师模型通过跨层非局部模块(如图3所示)提取多尺度像素间关系之后,执行步骤s204;
[0041]
s203:学生模型通过跨层非局部模块(如图3所示)提取多尺度像素间关系之后,执行步骤s204;
[0042]
s204:计算多尺度像素关系蒸馏损失。计算教师模型和学生模型间的多尺度像素关系蒸馏损失时,采用l2范式损失的形式
[0043]
l
蒸馏
=l2(r
t
,m(r
s
))
[0044]
其中,m(r
s
)是可学习的匹配函数,使教师模型和学生模型的多尺度关系特征图在维度和尺寸上匹配。在本实施例中,r
s
首先通过一个卷积层c(
·
),然后再通过一个上采样函数h(
·
)进行匹配,即m(r
s
)=h(c(r
s
))
[0045]
s205:将蒸馏损失加入到学生模型的损失函数中。采用如下公式求加权和得到总损失函数:
[0046]
l
总
=l
分类
αl
蒸馏
[0047]
其中,l
总
为总损失函数;l
分类
为分类损失函数,该函数采用交叉熵形式计算,具体公式为y为图像的真实分类标签,为学生模型输出的预测结果;l
蒸馏
为蒸馏损失函数;α为蒸馏损失函数在总损失函数中占的比例系数;
[0048]
s206:判断学生模型的准确率是否收敛,若是,则执行s207;否则,更新学生模型参数,且待参数更新完成之后,返回步骤s203;
[0049]
s207:将学生模型作为优化模型输出,即得到训练好的教师
‑
学生网络。
[0050]
在本发明实施例中,跨层非局部模块是为了提取查询层特征x
q
和响应层特征x
ri
之间的多尺度像素间关系,首先使用可学习的嵌入式函数θ(
·
)和φ(
·
)分别对x
q
和x
ri
进行预处理,投影到新的特征空间中。然后利用特征空间中的函数f(
·
,
·
)处理θ(x
q
)和φ(x
ri
),再经过一个归一化指数函数softmax层计算注意力映射。同时,响应层的位置特征由另一个可学习的嵌入函数g(
·
)投影得到。在输出端,利用卷积z保证输出的多尺度像素关系特征图与查询层特征图x
q
维度和尺寸上的一致性,以便于后续的叠加。因此,跨层非局部模块输出的多尺度像素关系特征图是多尺度像素关系和查询层特征图的叠加,该跨层非局部模块采用如下公式(1)进行计算:
[0051]
r=(x
q
,x
r1
,
…
,x
rn
)=x
q
∑z
ri
ꢀꢀ
(1)
[0052]
其中,x
q
为查询层特征;x
ri
为响应层特征;z
ri
为响应层i与查询层的像素间关系,表示为z为卷积运算;θ(
·
),φ(
·
)和g(
·
)均为可学习嵌入式函数,使用1
×
1卷积实现;θ(x
q
),gi(x
ri
)为可学习嵌入函数对输入的特征图做预处理,计算单个像素的表示;f(
·
,
·
)为二维函数,使用点积实现;为计算对应位置像素间的相关程度,即f(θ(x
q
),φ(x
ri
))=θ(x
q
)
t
φ(x
ri
)。
[0053]
在本发明实施例中,跨层非局部模块跨层非局部模块提取学生模型或教师模型的多尺度像素间关系的具体步骤如图4所示,具体包括:
[0054]
s401:将历史图像作为学生模型或教师模型的输入,并输入相应模型的第一层;
[0055]
s402:若第一层是选定的响应层,将第一层的输出特征作为响应层输入其后的跨层非局部模块,并将跨层非局部模块的输出特征输入其后的第二层;或若第一层是选定的查询层,将第一层的输出特征作为查询层输入其后的跨层非局部模块;其中,选择模型的倒数第二层作为响应层,倒数第三,四,五,六,七层作为查询层;
[0056]
s403:用第二层更新第一层;
[0057]
s404:若第一层是最后一层,将最后一层的输出特征作为预测结果并输出,即将跨层非局部模块的输出作为提取的多尺度像素关系特征用于知识蒸馏。
[0058]
实施本发明实施例,具有如下有益效果:
[0059]
本发明针对现有基于中间层的知识蒸馏方法中缺乏对多尺度像素间关系表征的问题,在教师
‑
学生网络的中间层插入一个可学习的跨层非局部模块,提取教师模型和学生模型的多尺度像素间关系,通过知识蒸馏的方式使学生模型拟合教师模型的多尺度像素间关系,改善学生模型的中间层输出,提升学生模型的特征提取能力,有效地提升知识蒸馏对教师
‑
学生网络优化效果,增强了图像分类效果,从而解决了现有技术中所存在的缺乏对多尺度像素间关系的表征以及背景噪声干扰的问题。
[0060]
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分步骤是可以通过程序来指令相关的硬件来完成,所述的程序可以存储于一计算机可读取存储介质中,所述的存储介质,如rom/ram、磁盘、光盘等。
[0061]
以上所揭露的仅为本发明一种较佳实施例而已,当然不能以此来限定本发明之权利范围,因此依本发明权利要求所作的等同变化,仍属本发明所涵盖的范围。
转载请注明原文地址:https://doc.8miu.com/index.php/read-1722542.html