您的位置 首页 PyTorch 教程

PyTorch 实现 U-Net 用于图像分割

PyTorch入门实战教程

图像分割(Image Segmentation)是图像领域里非常重要的一个问题,它将图像分割成不同大的部分,每个部分代表不同的区域(如下图)。

U-net 最初是用在医学图像领域,但其性能和结果都很好,也被用在了其它很多领域里。

U-net网络结构

整个架构看来像个 U,也就是为什么被称为 U-net。整个网络可以分为三部分:contraction、bottleneck,和 expansion,分别对用上图中的左边部分,下边部分,和右边部分。

Contraction 部分是由很多 contraction block 组成,每个 block 对输入做 3×3 大的卷积,然后是 2×2 的最大池化(max pooling)。每个 block 输出的特征图数量是上一个 block 的两倍,可以保证网络高效地学习复杂的图像特征。Bottleneck 部分则是两个 3×3 的 CNN 加上 2×2 的 Up Convolutional 层。对于 Expansion 部分,则包含很多 expansion block,每个 block 的输入上做 3×3 的 CNN 加上 2×2 的上采样(up sampling),并且每个 block 的特征图数量减半。最重要的是,每个 block 的输入都要与左边 contraction 的对用部分的输出合并。

损失函数

原论文中是这样描述损失函数的:

The energy function is computed by a pixel-wise soft-max over the final feature map combined with the cross-entropy loss function.

简单来讲,就是对与每个像素,应用 Softmax,然后用交叉熵损失函数(Cross Entropy),这样相当于将每个像素分为一类。

PyTorch 实现

下面的代码是我们的实现:

训练部分:

参考资料

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

评论列表(4)

返回顶部