基于变分自编码器与高斯混合模型的无监督目标检测方法及系统与流程

专利2022-05-09  16


本发明涉及的是一种人工智能领域的技术,具体是一种基于变分自编码器与高斯混合模型的无监督目标检测方法及系统。



背景技术:

目前的有监督学习仍需要大量经过标记的数据集,而这种处理需要大量的工作量,从而使得有用的数据集变得难以获取。同时,有监督学习得到的深度学习模型的泛用性较差,在不同的数据集上的性能可能会衰减。相比之下,无监督学习的最大特点就是不需要对数据进行标记,这大大减少了工作量。同时,无监督学习致力于得到通用的模型,提高模型的泛用性。而在无监督学习领域中,变分自编码器是一类非常重要的框架,虽然目前该框架基础上提出的air(attend,infer,repeat)、spair(spatiallyinvariantattend,infer,repeat)等模型具有一定效果,但仍存在只能在少量物体的场景下工作的弊端。



技术实现要素:

本发明针对现有无监督学习的目标检测对于分类与多物体场景下检测精度不足的问题以及现有基于变分自编码器的目标检测框架难以应对有大量物体的场景和对物体种类信息不敏感的缺陷,提出一种基于变分自编码器与高斯混合模型的无监督目标检测方法和系统,结合了空间注意力机制和高斯混合模型,不仅能够实现端到端的目标检测与分类,同时在存在大量物体的情况下仍有较好的性能,具有较好的扩展性。

本发明是通过以下技术方案实现的:

本发明涉及一种基于变分自编码器与高斯混合模型的无监督目标检测方法,通过骨干网络将输入图像转化一个h*w维度,即h*w个单元格的特征图,再将该特征图编码为先验分布符合高斯混合模型的隐变量,然后由解码器根据隐变量进行图像重构,并将重构的图像与输入图像进行比较并计算损失函数,从而训练神经网络,编码器得到图像中物体的类别与位置等信息,从而实现无监督目标检测。

所述的骨干网络是指:由深度残差网络连接反卷积层形成的提取图像特征的神经网络。

所述的隐变量包括:每个单元格中物体的类别、特征、位置、深度、出现概率信息。

所述的转化,具体过程为:

,其中:二元变量zpres用于表示物体是否存在,zpres=1表示物体存在于图像中;zwhere可以被分解为(zy,zx,zh,zw),zy与zx是该单元格中所包含物体的中心位置,zw与zh是该物体的宽和高;zdepth为物体在图像中的深度信息,用来处理物体堆叠的问题,zcat是一个为物体类别的c维的一位有效编码,c是图像集中物体总的类别数;zwhat是一个分布符合高斯混合模型的物体特征向量的a维编码,具体为:是高斯分布的概率密度函数,μk与是第k个类别的高斯分布的期望与方差,μk与作为可学习的参数。

在重构图像中,深度信息更低的物体会出现在更高的物体上面。

技术效果

本发明整体解决了现有无监督目标检测技术的不能同时完成物体定位于聚类的缺陷;与现有技术相比,本发明结合空间注意力机制和高斯混合模型,使得目标定位与聚类两个任务可并行完成,并使得两个任务均有较好的效果。同时在场景中存在大量物体的情况下仍有较好的性能。

附图说明

图1为本发明网络结构示意图。

具体实施方式

如图1所示,为本实施例涉及一种基于变分自编码器与高斯混合模型的无监督目标检测系统,包括:骨干网络以及分别与之相连的pres-预测头、depth预测头和where-预测头;空间转换网络以及分别与之相连的what-编码器和cat-编码器;单元格解码器、what先验网络以及可微分渲染器,其中:骨干网络对输入图像进行预处理得到特征图,pres-预测头、depth预测头和where-预测头分别根据特征图得到zpres隐变量、zwhere隐变量和zdepth隐变量,空间转换网络根据输入图像和zwhere隐变量进行空间变换处理并分别输出单元格信息至cat-编码器和what-编码器,cat-编码器根据单元格信息得到zcat隐变量,what-编码器根据每个单元格与其对应的zcat隐变量拼接后得到zwhat隐变量,单元格解码器根据zwhat隐变量重新生成单元格,可微分渲染器根据zpres隐变量、zwhere隐变量、zdepth隐变量和重新生成的单元格,经渲染并输出重构图像,what先验网络根据zcat隐变量生成zwhat隐变量的先验分布。

所有的神经网络参数通过最小化损失函数进行训练;训练后的pres-预测头输出的zpres隐变量即代表每个单元格是否存在物体,训练后的where-预测头输出的zwhere隐变量即代表存在物体时每个单元格中物体的具体位置;训练后的cat-预测头输出的zcat隐变量即为存在物体时每个单元格中物体的类别,从而完成在统一的网络结构中,对图像中的物体进行无监督,即用于训练数据集中的数据不需要人工标注的目标检测,获得图片中所关注物体的位置与类别信息。

所述的损失函数为重构损失与正则化损失之和,其中:重构损失为原图与重构图像间的二值交叉熵,正则化损失为每个隐变量的分布q(z*|x)与其预设的先验分布p(z*)的kl散度。其中zwhat隐变量的先验分布由what先验网络生成。

所述的二值交叉熵描述原图与重构图像间的差异,其越小表示原图与重构图像间差异越小。

本实施例中的pres-预测头、depth-预测头、where-预测头均为四层卷积神经网络,其具体网络结构参数如表1所示。

表1pres-检测头、depth-检测头与where-检测头的结构参数

所述的空间转换网络内含一个空间变换处理模块,该空间变换处理模块读取where-预测头输出的zwhere隐变量并得到每个单元格中表征物体具体位置的一个矩形框的预判后,通过把该矩形框平移到原点,并放缩到固定大小32×32得到各单元格信息。

所述的what-编码器与cat-编码器为多层感知机,其具体结构参数如表2所示。

表2what-编码器与cat-编码器的结构参数

所述的单元格解码器为深度卷积神经网络,其具体结构参数如表3所示。

表3单元格解码器的结构参数

所述的what先验网络为两个独立的单层感知机,大小各为10×256。

所述的可微分渲染器通过表征位置信息的zwhere隐变量将重新生成的单元格还原到其预判的位置;然后设置单元格中每个像素的值为所有覆盖它的单元格对应值,即zpres隐变量与其深度,即zdepth隐变量的加权平均。

经过具体实际实验,以表4所示参数设置损失函数进行网络训练。在multimnist数据集下,定位平均准确度为97.3±0.10,聚类准确度为80.4±0.48,聚类归一化互信息指标(nmi)为75.5±0.66。在fruit2d数据集下定位平均准确度为84.9±1.56,聚类准确度为90.9±0.32,聚类nmi为85.7±1.25,x±y中,x与y分别代表取多次随机因子下,实验结果的平均值与标准差。

表4先验分布参数设置

与现有技术相比,本系统在不损失定位精度与聚类准确度的前提下,同时完成无监督目标检测的定位与聚类任务。

上述具体实施可由本领域技术人员在不背离本发明原理和宗旨的前提下以不同的方式对其进行局部调整,本发明的保护范围以权利要求书为准且不由上述具体实施所限,在其范围内的各个实现方案均受本发明之约束。

转载请注明原文地址:https://doc.8miu.com/read-250225.html

最新回复(0)