简介
densenet网络是CVPR 2017 (Best Paper Award),这篇论文是在Stochastic Depth的启发下提出的。densenet和Stochastic Depth都是清华的黄高博士提出的。关于densenet的详细介绍可以看【0】,关于Stochastic Depth的详细介绍可以看【1】。
论文链接:https://arxiv.org/pdf/1608.06993.pdf
代码的github链接:https://github.com/liuzhuang13/DenseNet
DenseNet(密集卷积网络)的核心思想是密集连接,即某层的输入除了包含前一层的输出外还包含前面所有层的输出,因此L层的模型会有L(L-1)/2个连接,具体结构图如下图所示,DenseNet的几个优点是:
1,减轻了消失梯度(梯度消失)
2,加强了特征的传递
3,更有效地利用了特征
4,一定程度上较少了参数数量
在保证网络中层与层之间最大程度的信息传输的前提下,直接将所有层连接起来!
模型细节
整个densenet模型主要包含三个核心细节结构,分别是DenseLayer(整个模型最基础的原子单元,完成一次最基础的特征提取,如下图第三行)、DenseBlock(整个模型密集连接的基础单元,如下图第二行左侧部分)和Transition(不同密集连接之间的过度单元,如下图第二行右侧部分),通过以上结构的拼接+分类层即可完成整个模型的搭建。
DenseLayer层包含BN + Relu + 1*1Conv + BN + Relu + 3*3Conv。第L个DenseLayer层的第一个1*1Conv的输入通道层数为num_input_features+(L-1)*growth_rate,输出通道层数为bn_size*growth_rate;第二个3*3Conv的输入通道数为bn_size*growth_rate,输出通道数为growth_rate。整个DenseLayer层内特征层宽度不变,不存在stride=2或者池化的情况。这里有一点特殊之处,DenseLayer层的第一个结构是BN层而不是像其它模型那样是Conv。在BN层前面还存在一个Concatenation操作,负责本DenseBlock模块内前面所有层的输出以及第一层的输出进行拼接操作,
DenseBlock模块其实就是堆叠一定数量的DenseLayer层,在整个DenseBlock模块内不同DenseLayer层之间会发生密集连接,在DenseBlock模块内特征层宽度不变,不存在stride=2或者池化的情况。
Transition模块包含BN + Relu + 1*1Conv + 2*2AvgPool,1*1Conv负责降低通道数,2*2AvgPool负责降低特征层宽度,降低到1/2。Transition模块的作用是连接不同的DenseBlock模块,之所以这样设计原因是,密接连接必须保证特征层的宽度是一致的,原因是连接方式为沿通道维拼接,如果整个模型都采用密集连接,那势必导致整个模型从输入到输出特征层宽度都不变,那最后无法完成分类任务,也无法压缩特征。
模型可能优点
更强的梯度流动:
DenseNet可以说是一种隐式的强监督模式,因为每一层都建立起了与前面层的连接,误差信号可以很容易地传播到较早的层,所以较早的层可以从最终分类层获得直接监管。
参数更少计算效率更高
在ResNet中,参数量与C*C成正比,而在DenseNet中参数量与l*k*k成正比,因为k远小于C,所以DenseNet的参数量小得多。
保存了低维度的特征
在标准的卷积网络中,最终输出只会利用提取最高层次的特征。而在DenseNet中,它使用了不同层次的特征,它倾向于给出更平滑的决策边界。这也解释了为什么训练数据不足时DenseNet表现依旧良好。
模型效果
该文章提出的DenseNet核心思想在于建立了不同层之间的连接关系,充分利用了功能,进一步减轻了梯度消失问题,加深网络不是问题,而且训练效果非常好。另外,利用瓶颈层,翻译层以及较小的增长率使得网络变窄,参数减少,有效抑制了过拟合,同时计算量也减少了DenseNet优点很多,而且在和RESNET的对比中优势还是非常明显的。【2】
模型代码
改代码修改自torch官方代码
# 根据torch官方代码修改的densenet代码
# 模型下载地址:
# 121 --- "), 224)
image = transform(image)
image = image.reshape(1, 3, 224, 224)
# 建立模型并恢复权重
weight_path = "./checkpoint/densenet121-a639ec97.pth" # 这个预训练权重是老版本torch生成的,当时模块的命名允许出现"."
pre_weights = torch.load(weight_path) # 但是最新的torch不允许出现".",所以老版权重恢复进新版模型时需要修改一下模块命名
pattern = re.compile(r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$")
for key in list(pre_weights.keys()): # 主要是新版模型中的最基础模块的命名是类似于...denselayer1.conv1.weight
res = pattern.match(key) # 而老版本权重的命名类似于 ...denselayer1.conv.1.weight
if res: # 所以需要正则表达式去老版本权重的key中匹配一下,一旦匹配成功就修改为最新模型的权重名称
new_key = res.group(1) + res.group(2) # 正则表达式中()的作用是提取满足匹配要求的字符串,group(0)就是匹配正则表达式整体结果
pre_weights[new_key] = pre_weights[key]
del pre_weights[key]
model = densenet121()
model.load_state_dict(pre_weights)
# print(model)
# 单张图片推理
model.cpu().eval() # .eval()用于通知BN层和dropout层,采用推理模式而不是训练模式
with torch.no_grad(): # torch.no_grad()用于整体修改模型中每一层的requires_grad属性,使得所有可训练参数不能修改,且正向计算时不保存中间过程,以节省内存
output = torch.squeeze(model(image))
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
# 输出结果
print(predict_cla)
print(predict[predict_cla])
包含训练和测试的完整代码见:https://github.com/LegendBIT/torch-classification-model
参考:
0. 深入解析DenseNet(含大量可视化及计算)
1. CNN模型合集 | 9 Stochastic_Depth
2. DenseNet算法详解