Image segmentation is a very important problem in computer vision: it partitions an image into different parts, with each part corresponding to a different region (see the figure below).
U-net was originally designed for medical imaging, but its strong performance and results have led to its adoption in many other fields as well.
U-net Network Architecture
The whole architecture looks like a U, which is why it is called U-net. The network can be divided into three parts: contraction, bottleneck, and expansion, corresponding to the left, bottom, and right portions of the figure above.
The contraction part consists of a series of contraction blocks. Each block applies 3×3 convolutions to its input, followed by 2×2 max pooling, and produces twice as many feature maps as the previous block, which lets the network learn complex image features efficiently. The bottleneck consists of two 3×3 convolutional layers followed by a 2×2 up-convolution. The expansion part contains a series of expansion blocks; each block applies 3×3 convolutions and 2×2 up-sampling to its input, and halves the number of feature maps. Most importantly, the input of each expansion block is concatenated with the output of the corresponding contraction block on the left.
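To make the size bookkeeping concrete, the spatial dimensions along the contracting path can be traced by hand. The sketch below assumes unpadded 3×3 convolutions and 2×2 max pooling, starting from the 572×572 input size used in the original paper:

```python
def conv3x3(size):
    """An unpadded 3x3 convolution shrinks each spatial dim by 2."""
    return size - 2

def maxpool2x2(size):
    """A 2x2 max pooling halves each spatial dim."""
    return size // 2

size = 572  # input size from the original U-net paper
trace = [size]
for _ in range(3):  # three contraction blocks
    size = conv3x3(conv3x3(size))  # two unpadded 3x3 convs per block
    trace.append(size)
    size = maxpool2x2(size)
    trace.append(size)

print(trace)  # [572, 568, 284, 280, 140, 136, 68]
```

Each block trims 4 pixels off each side-length (two unpadded convolutions) before the pooling halves it, which is why the expansion path later has to crop the skip connections.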
Loss Function
The original paper describes the loss function as follows:
The energy function is computed by a pixel-wise soft-max over the final feature map combined with the cross-entropy loss function.
In short, a softmax is applied to each pixel and the cross-entropy loss is then computed, which effectively classifies every pixel into one of the classes.
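For a single pixel, this computation can be done by hand. The minimal sketch below uses made-up logits for a two-class problem:

```python
import math

# Hypothetical logits for one pixel over 2 classes
logits = [2.0, 0.5]
true_class = 0

# Pixel-wise softmax over the class dimension
exps = [math.exp(z) for z in logits]
probs = [e / sum(exps) for e in exps]

# Cross-entropy: negative log-probability of the true class
loss = -math.log(probs[true_class])
print(round(loss, 4))  # 0.2014
```

The per-pixel losses are then averaged over all pixels of the final feature map.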
PyTorch Implementation
The code below is our implementation:
```python
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim


class UNet(nn.Module):
    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels,
                                     kernel_size=3, stride=2, padding=1, output_padding=1),
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        # Bottleneck
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.ConvTranspose2d(in_channels=512, out_channels=256,
                                     kernel_size=3, stride=2, padding=1, output_padding=1),
        )
        # Decode
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        if crop:
            # Negative padding in F.pad crops the bypass feature map
            # so its spatial size matches the upsampled one
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            bypass = F.pad(bypass, (-c, -c, -c, -c))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        final_layer = self.final_layer(decode_block1)
        return final_layer
```
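Because the convolutions are unpadded, the output map is smaller than the input, and the input size must be chosen so every pooling and skip-connection crop lines up. The pure-arithmetic trace below walks a hypothetical 188×188 input through the network above (one size that keeps every stage aligned):

```python
def conv3x3(s):
    """Unpadded 3x3 conv: shrinks each spatial dim by 2."""
    return s - 2

def block(s):
    """Two unpadded 3x3 convs in a row."""
    return conv3x3(conv3x3(s))

def up(s):
    """ConvTranspose2d(kernel_size=3, stride=2, padding=1, output_padding=1)
    exactly doubles the spatial size: (s - 1) * 2 - 2 + 3 + 1 == 2 * s."""
    return (s - 1) * 2 - 2 + 3 + 1

s = 188                            # hypothetical input size
e1 = block(s)                      # 184, skip connection 1
e2 = block(e1 // 2)                # 88,  skip connection 2
e3 = block(e2 // 2)                # 40,  skip connection 3
bottleneck = up(block(e3 // 2))    # 32; e3 is cropped from 40 down to 32
d3 = up(block(bottleneck))         # 56; e2 is cropped from 88 down to 56
d2 = up(block(d3))                 # 104; e1 is cropped from 184 down to 104
out = block(d2)                    # 100; the final padded 3x3 conv keeps the size
print(out)  # 100
```

So a 188×188 input yields a 100×100 segmentation map, and each crop difference (8, 32, 80) is even, which the symmetric cropping in `crop_and_concat` requires.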
Training code:
```python
unet = UNet(in_channel=1, out_channel=2)  # out_channel represents number of segments desired
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(unet.parameters(), lr=0.01, momentum=0.99)

optimizer.zero_grad()
outputs = unet(inputs)
# permute such that number of desired segments would be on 4th dimension
outputs = outputs.permute(0, 2, 3, 1)
m = outputs.shape[0]
# Reshaping the outputs and labels to calculate pixel-wise softmax loss
outputs = outputs.reshape(m * width_out * height_out, 2)
labels = labels.reshape(m * width_out * height_out)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
```
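As an aside, `torch.nn.CrossEntropyLoss` also accepts spatial inputs directly, with logits of shape (N, C, H, W) and integer labels of shape (N, H, W), so the permute-and-reshape step can be skipped. A minimal sketch with random stand-in data (the sizes here are arbitrary):

```python
import torch
from torch import nn

criterion = nn.CrossEntropyLoss()

# Random stand-ins for the network output and the ground-truth mask
outputs = torch.randn(2, 2, 100, 100)        # (N, C, H, W) logits
labels = torch.randint(0, 2, (2, 100, 100))  # (N, H, W) class indices
loss = criterion(outputs, labels)            # scalar, averaged over all pixels
print(loss.dim())  # 0
```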