ginka-generator/ginka/model/unet.py

92 lines
2.9 KiB
Python

import torch
import torch.nn as nn
class GinkaEncoder(nn.Module):
    """Encoder (downsampling) stage of the UNet.

    Applies two ``3x3 conv -> BatchNorm -> ReLU`` layers (padding 1, so the
    spatial size is unchanged), followed by a 2x2 max-pool.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First conv maps in_channels -> out_channels, second keeps out_channels.
        for conv_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        # Layer indices 0..5 inside ``conv`` appear in the state_dict keys.
        self.conv = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``(pooled, features)``.

        ``features`` is the conv-stack output before pooling (used as the
        skip connection by the UNet); ``pooled`` is ``features`` after 2x2
        max-pooling, which halves the spatial size (rounding down).
        """
        features = self.conv(x)
        pooled = self.pool(features)
        return pooled, features
class GinkaDecoder(nn.Module):
    """Decoder (upsampling) stage of the UNet.

    Upsamples the incoming features (bilinear interpolation followed by a
    ``3x3 conv -> BatchNorm -> ReLU`` mapping ``in_channels`` to
    ``out_channels``). It then concatenates them with the encoder skip
    features along the channel axis and fuses the result with another
    ``3x3 conv -> BatchNorm -> ReLU``.

    Args:
        in_channels: channel count of the features coming from the coarser stage.
        out_channels: channel count of the skip features and of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Upsampling: bilinear interpolation + convolution.
        # Index 0 (nn.Upsample) holds no parameters. It is kept so the conv/BN
        # layers stay at indices 1 and 2 in the state_dict, and forward() reads
        # its mode/align_corners settings.
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        # Skip-connection fusion: concat(upsampled, skip) -> out_channels.
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x, skip):
        """Upsample ``x`` to ``skip``'s spatial size and fuse the two.

        Args:
            x: features from the coarser stage, ``(N, in_channels, h, w)``.
            skip: encoder skip features, ``(N, out_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.
        """
        interp = self.upsample[0]
        # Interpolate straight to the skip size rather than a fixed 2x. When the
        # encoder pooled an odd-sized map (MaxPool2d rounds down), a fixed 2x
        # result is one pixel short and torch.cat would fail. With
        # align_corners=True the output equals the fixed 2x upsample whenever
        # H == 2h and W == 2w.
        x = nn.functional.interpolate(
            x,
            size=tuple(skip.shape[-2:]),
            mode=interp.mode,
            align_corners=interp.align_corners,
        )
        # Remaining conv -> BN -> ReLU of the upsampling branch.
        x = self.upsample[1:](x)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)
class GinkaBottleneck(nn.Module):
    """Bottleneck block between the encoder and decoder paths.

    Two ``3x3 conv -> BatchNorm -> ReLU`` layers with padding 1. The spatial
    size is unchanged; channels go from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        first = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        second = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        # Flattened into one Sequential so state_dict keys are conv.0 .. conv.5.
        self.conv = nn.Sequential(*first, *second)

    def forward(self, x):
        """Apply the two conv layers; output has ``out_channels`` channels."""
        out = self.conv(x)
        return out
class GinkaUNet(nn.Module):
    """UNet part of the Ginka model.

    The network is built from these stages:

    - Two encoder stages (channels ``in_ch -> 2*in_ch -> 4*in_ch``).
    - A bottleneck at ``4*in_ch``.
    - Two decoder stages back down to ``in_ch``, each fusing the matching
      encoder skip features.
    - A final 1x1 conv projecting to ``out_ch`` channels.

    Args:
        in_ch: channel count of the input feature map (default 64).
        out_ch: channel count of the output (default 32).
    """

    def __init__(self, in_ch=64, out_ch=32):
        super().__init__()
        base, mid, deep = in_ch, in_ch * 2, in_ch * 4
        self.down1 = GinkaEncoder(base, mid)
        self.down2 = GinkaEncoder(mid, deep)
        self.bottleneck = GinkaBottleneck(deep, deep)
        self.up1 = GinkaDecoder(deep, mid)
        self.up2 = GinkaDecoder(mid, base)
        # One-element Sequential: weights are stored under "final.0.*".
        self.final = nn.Sequential(nn.Conv2d(base, out_ch, 1))

    def forward(self, x):
        """Run encoder -> bottleneck -> decoder and project to ``out_ch``."""
        h, shallow_skip = self.down1(x)
        h, deep_skip = self.down2(h)
        h = self.up1(self.bottleneck(h), deep_skip)
        h = self.up2(h, shallow_skip)
        return self.final(h)