mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 06:51:11 +08:00
107 lines
3.9 KiB
Python
107 lines
3.9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from shared.attention import CBAM, SEBlock
|
|
|
|
class GinkaEncoder(nn.Module):
    """Encoder (downsampling) stage of the Ginka UNet.

    Applies conv3x3 -> InstanceNorm -> GELU -> conv3x3 -> InstanceNorm,
    optionally an attention module, then GELU. The result is kept as the skip
    tensor and is also downsampled by a stride-2 3x3 convolution.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the conv stack and by the
            downsampling convolution.
        attention: whether to insert an attention module after the convs.
        block: attention module used when ``attention`` is True; must be
            ``'CBAM'`` or ``'SEBlock'``.

    Raises:
        ValueError: if ``attention`` is True and ``block`` is not a supported name.
    """

    def __init__(self, in_channels, out_channels, attention=False, block='CBAM'):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
        )
        # Attention
        if attention:
            if block == 'CBAM':
                self.conv.append(CBAM(out_channels))
            elif block == 'SEBlock':
                self.conv.append(SEBlock(out_channels))
            else:
                # An unknown name used to silently disable attention; fail loudly
                # so a typo cannot quietly change the architecture.
                raise ValueError(
                    f"unknown attention block {block!r}; expected 'CBAM' or 'SEBlock'"
                )
        self.conv.append(nn.GELU())
        # kernel 3 / stride 2 / padding 1 maps a spatial size H to ceil(H / 2).
        self.down = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)

    def forward(self, x):
        """Run the stage.

        Returns:
            ``(x_down, x_res)``: ``x_res`` keeps the spatial size of ``x`` (all
            convs in ``self.conv`` use padding 1) and is meant as the skip
            connection; ``x_down`` is ``x_res`` downsampled to ceil(H/2) x ceil(W/2).
        """
        x_res = self.conv(x)
        x_down = self.down(x_res)
        return x_down, x_res
|
|
|
|
class GinkaDecoder(nn.Module):
    """Decoder (upsampling) stage of the Ginka UNet.

    Upsamples ``x`` with a 2x2 stride-2 transposed convolution, concatenates
    it with the skip tensor along the channel dimension, then applies
    conv3x3 -> InstanceNorm, optionally an attention module, and GELU.

    Because the fused conv expects ``in_channels + out_channels`` inputs and the
    upsampling produces ``out_channels``, the skip tensor must carry
    ``in_channels`` channels.

    Args:
        in_channels: channels of the incoming (lower-resolution) tensor and of
            the skip tensor.
        out_channels: channels produced by the upsampling and by the stage.
        attention: whether to insert an attention module after the conv.
        block: attention module used when ``attention`` is True; must be
            ``'CBAM'`` or ``'SEBlock'``.

    Raises:
        ValueError: if ``attention`` is True and ``block`` is not a supported name.
    """

    def __init__(self, in_channels, out_channels, attention=False, block='CBAM'):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
        )
        # Attention
        if attention:
            if block == 'CBAM':
                self.conv.append(CBAM(out_channels))
            elif block == 'SEBlock':
                self.conv.append(SEBlock(out_channels))
            else:
                # An unknown name used to silently disable attention; fail loudly.
                raise ValueError(
                    f"unknown attention block {block!r}; expected 'CBAM' or 'SEBlock'"
                )
        self.conv.append(nn.GELU())

    @staticmethod
    def _match_spatial(x, ref):
        """Crop (bottom/right) or zero-pad ``x`` so its last two dims equal ``ref``'s."""
        h, w = ref.shape[-2:]
        x = x[..., :h, :w]
        pad_h = h - x.shape[-2]
        pad_w = w - x.shape[-1]
        if pad_h or pad_w:
            # F.pad order for the last two dims is (left, right, top, bottom).
            x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        return x

    def forward(self, x, skip):
        """Upsample ``x``, fuse it with ``skip`` and return the stage output.

        The output has ``out_channels`` channels and the spatial size of ``skip``.
        """
        x = self.upsample(x)
        # Doubling the size cannot recover an odd skip size (e.g. a stride-2
        # conv maps 13 -> 7, and 7 * 2 = 14), so align before concatenating;
        # this is a no-op when the sizes already match.
        x = self._match_spatial(x, skip)
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x
|
|
|
|
class GinkaBottleneck(nn.Module):
    """Bottleneck stage sitting between the encoder and decoder paths.

    Applies conv3x3 -> InstanceNorm -> GELU -> conv3x3 -> InstanceNorm,
    optionally an SEBlock, and a final GELU. All convs use padding 1, so the
    spatial size is preserved.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        attention: whether to insert an SEBlock before the final GELU.
    """

    def __init__(self, in_channels, out_channels, attention=False):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
        ]
        if attention:
            layers.append(SEBlock(out_channels))
        layers.append(nn.GELU())
        # Position inside the Sequential fixes the state_dict keys (conv.0, conv.3, ...).
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the bottleneck layer stack to ``x``."""
        out = self.conv(x)
        return out
|
|
|
|
class GinkaUNet(nn.Module):
    """UNet component of the Ginka model.

    Four encoder stages widen the channels in_ch -> 2x -> 4x -> 8x -> 16x while
    downsampling by 2, a bottleneck keeps 16x in_ch channels, four decoder
    stages mirror the encoder using the saved skip tensors, and a 1x1
    convolution maps in_ch channels to out_ch.

    Args:
        in_ch: channel count of the input tensor (default 64).
        out_ch: channel count of the output tensor (default 32).
    """

    def __init__(self, in_ch=64, out_ch=32):
        super().__init__()
        c = in_ch
        # Registration order below is kept stable: it fixes parameter
        # initialisation order and state_dict layout.
        # The two shallow stages use the default attention block ('CBAM');
        # the two deep stages use 'SEBlock'.
        self.down1 = GinkaEncoder(c, 2 * c, attention=True)
        self.down2 = GinkaEncoder(2 * c, 4 * c, attention=True)
        self.down3 = GinkaEncoder(4 * c, 8 * c, attention=True, block='SEBlock')
        self.down4 = GinkaEncoder(8 * c, 16 * c, attention=True, block='SEBlock')

        self.bottleneck = GinkaBottleneck(16 * c, 16 * c, attention=True)

        self.up1 = GinkaDecoder(16 * c, 8 * c, attention=True, block='SEBlock')
        self.up2 = GinkaDecoder(8 * c, 4 * c, attention=True, block='SEBlock')
        self.up3 = GinkaDecoder(4 * c, 2 * c, attention=True)
        self.up4 = GinkaDecoder(2 * c, c, attention=True)

        self.final = nn.Sequential(nn.Conv2d(c, out_ch, 1))

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the 1x1-projected output."""
        skips = []
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            x, skip = encoder(x)
            skips.append(skip)

        x = self.bottleneck(x)

        # Deepest decoder pairs with the deepest skip.
        decoders = (self.up1, self.up2, self.up3, self.up4)
        for decoder, skip in zip(decoders, reversed(skips)):
            x = decoder(x, skip)

        return self.final(x)
|