PyTorch 实现 RetinaNet 目标检测

PyTorch入门实战教程

这篇文章介绍一个 PyTorch 实现的 RetinaNet 实现目标检测。文章的思想来自论文:Focal Loss for Dense Object Detection。这个实现的主要目标是为了方便读者能够很好的理解和更改源代码。

img3 img5

结果

当前的实现能达到 33.7% 的 mAP(600px 分辨率,Resnet-50)。论文里的结果是 34.0% mAP,造成这个差别的主要原因可能是这里使用了 Adam 优化器,而论文里使用了 SGD 和 weight decay。

安装

1. 用 Git 克隆 https://github.com/yhenon/pytorch-retinanet

2. 安装必备包:

3. 安装 Python 包:

4. 编译 NMS 扩展.

怎样训练

训练主要用 train.py 文件。现在可用的训练数据有两个: COCO 和 CSV。

要训练 COCO:

如果要训练自己的数据集,要用 CSV 格式:

预训练模型

可以在这里下载预训练模型:链接

项目地址

项目在 Github 上,点击访问

PyTorch入门实战教程

Leave a Reply

Your email address will not be published. Required fields are marked *

4条评论

  1. 在使用csv数据集时,因为给出的评价指标是每一类的mAP,那最后这33.7%是所有类求平均得来的吗?

  2. 在使用csv数据集时,因为给出的评价指标是每一类的mAP,那最后这33.7%是所有类求平均得来的吗?

返回顶部