这个项目用 PyTorch 实现了基本的 Batchnorm Fusion 方法,可以应用于大部分的流行 CNN 架构。这个项目的主要目的是为了提升神经网络在测试时的推理速度(inference time)。理想提速值可以高达 30%!
原理
我们知道通常情况下卷积和 Batchnorm 都是关于数据 x 的线性操作,可以被写成矩阵相乘的形式:
$$T_{bn} * S_{bn} * W_{conv} * x$$
Batch Norm Fusion 先给数据做卷积操作,然后再应用 batchnorm。
支持的架构
这个项目支持 Conv 和 BN 一起的任意连续模型。为了方便,他们提供了 VGG,ResNet,和 SeNet 的示例。
- VGG from torchvision
- ResNet Family from
torchvision
- SeNet family from
pretrainedmodels
使用示例
1 2 3 4 5 6 7 | import torchvision.models as models from bn_fusion import fuse_bn_recursively net = getattr(models,\'vgg16_bn\')(pretrained=True) net = fuse_bn_recursively(net) net.eval() # Make inference with the converted model |
项目地址
项目托管在 Github 上,地址是这里。
本站微信群、QQ群(三群号 726282629):