ginka-generator/ginka/model/unet.py

114 lines
3.5 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    Both convolutions use ``padding=1``, so the height and width of the input
    are preserved; only the channel count changes from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels produced by each of the two stages.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first stage maps in_ch -> out_ch, the second keeps out_ch.
        # Layers are created in the same order as they are applied so that
        # parameter names and seeded initialisation stay stable.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv/BatchNorm/ReLU stack to ``x`` and return the result."""
        return self.conv(x)
class ConditionFusionBlock(nn.Module):
    """Add a conditioning feature to a tensor, scaled by a learnable scalar.

    The block computes ``x + alpha * cond_feat`` where ``alpha`` is a single
    trainable scalar initialised to 0.5.
    """

    def __init__(self):
        super().__init__()
        # Learnable fusion coefficient (0-dim tensor, starts at 0.5).
        initial_alpha = torch.tensor(0.5)
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x, cond_feat):
        """Return ``x`` with the weighted conditioning feature added.

        Args:
            x: Tensor to be conditioned.
            cond_feat: Conditioning tensor; it is added to ``x`` elementwise,
                so it must be broadcast-compatible with ``x``.
        """
        # Residual-style fusion: the input passes through and the scaled
        # condition is added on top.
        weighted = self.alpha * cond_feat
        return x + weighted
class FusionConvBlock(nn.Module):
    """A ``ConvBlock`` whose output is fused with a conditioning feature.

    Args:
        in_ch: Number of input channels passed to the ``ConvBlock``.
        out_ch: Number of output channels produced by the ``ConvBlock``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.fusion = ConditionFusionBlock()

    def forward(self, x, feat):
        """Convolve ``x`` and then add the weighted condition ``feat``.

        ``feat`` is added to the convolved tensor, so it must be
        broadcast-compatible with the ``ConvBlock`` output.
        """
        convolved = self.conv(x)
        return self.fusion(convolved, feat)
class GinkaEncoder(nn.Module):
    """Encoder (downsampling) stage.

    Applies a ``ConvBlock``, halves the spatial resolution with a 2x2
    max-pool, and finally fuses the pooled result with a conditioning feature.

    Args:
        in_ch: Number of input channels passed to the ``ConvBlock``.
        out_ch: Number of output channels produced by the ``ConvBlock``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fusion = ConditionFusionBlock()

    def forward(self, x, feat):
        """Return the downsampled, condition-fused features.

        ``feat`` is added after pooling, so it must be broadcast-compatible
        with the pooled tensor rather than with ``x``.
        """
        pooled = self.pool(self.conv(x))
        return self.fusion(pooled, feat)
class GinkaUpSample(nn.Module):
    """2x learned upsampling: transposed convolution, BatchNorm, then GELU.

    The transposed convolution uses a 2x2 kernel with stride 2, which doubles
    the height and width of the input.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels of the upsampled feature map.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        transposed = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        norm = nn.BatchNorm2d(out_ch)
        activation = nn.GELU()
        self.conv = nn.Sequential(transposed, norm, activation)

    def forward(self, x):
        """Upsample ``x`` and return the result."""
        return self.conv(x)
class GinkaDecoder(nn.Module):
    """Decoder (upsampling) stage.

    Upsamples the incoming features by 2x while halving their channel count,
    concatenates them with the encoder skip connection along the channel
    axis, runs a ``ConvBlock`` and fuses the result with a conditioning
    feature.

    Args:
        in_ch: Number of channels of the incoming feature map. The upsampled
            map has ``in_ch // 2`` channels and the ``ConvBlock`` expects
            ``in_ch`` channels after concatenation, so the skip tensor must
            provide the remaining ``in_ch - in_ch // 2`` channels.
        out_ch: Number of channels produced by this stage.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
        self.conv = ConvBlock(in_ch, out_ch)
        self.fusion = ConditionFusionBlock()

    def forward(self, x, skip, feat):
        """Decode ``x`` using the skip connection and conditioning feature.

        Args:
            x: Features from the previous (coarser) stage.
            skip: Encoder features at the target resolution.
            feat: Conditioning tensor added to the stage output; it must be
                broadcast-compatible with that output.

        Returns:
            The fused feature map with ``out_ch`` channels at the spatial
            size of ``skip``.
        """
        dec = self.upsample(x)

        # A 2x2 max-pool floors odd sizes while the transposed convolution
        # exactly doubles, so the upsampled map can come out smaller than the
        # skip tensor. Align it to the skip's spatial size before
        # concatenation; nothing happens when the sizes already agree.
        diff_h = skip.size(-2) - dec.size(-2)
        diff_w = skip.size(-1) - dec.size(-1)
        if diff_h != 0 or diff_w != 0:
            # F.pad takes (left, right, top, bottom); negative values crop.
            dec = F.pad(
                dec,
                [
                    diff_w // 2,
                    diff_w - diff_w // 2,
                    diff_h // 2,
                    diff_h - diff_h // 2,
                ],
            )

        x = torch.cat([dec, skip], dim=1)
        x = self.conv(x)
        x = self.fusion(x, feat)
        return x
class GinkaUNet(nn.Module):
    """UNet part of the Ginka model.

    The network has an input stage, four downsampling stages (the last one
    acting as the bottleneck) and four upsampling stages with skip
    connections, followed by a 1x1 convolution. Every stage fuses its output
    with a per-level conditioning feature.

    Args:
        in_ch: Number of channels of the input tensor.
        base_ch: Channel width of the first stage; each deeper level doubles
            it, up to ``base_ch * 16`` at the bottleneck.
        out_ch: Number of channels produced by the final 1x1 convolution.
        feat_dim: Accepted for interface compatibility; not used by this
            class.
    """

    def __init__(self, in_ch=1, base_ch=64, out_ch=32, feat_dim=1024):
        super().__init__()
        # Channel widths for levels 0..4 (full resolution down to bottleneck).
        w0, w1, w2, w3, w4 = (base_ch * 2 ** level for level in range(5))

        # Contracting path.
        self.in_conv = FusionConvBlock(in_ch, w0)
        self.down1 = GinkaEncoder(w0, w1)
        self.down2 = GinkaEncoder(w1, w2)
        self.down3 = GinkaEncoder(w2, w3)
        self.bottleneck = GinkaEncoder(w3, w4)

        # Expanding path, mirroring the encoder widths.
        self.up1 = GinkaDecoder(w4, w3)
        self.up2 = GinkaDecoder(w3, w2)
        self.up3 = GinkaDecoder(w2, w1)
        self.up4 = GinkaDecoder(w1, w0)

        # Output projection; kept inside a Sequential so the parameter
        # names remain ``final.0.*``.
        self.final = nn.Sequential(
            nn.Conv2d(w0, out_ch, kernel_size=1),
        )

    def forward(self, x, feat, cond):
        """Run the conditional UNet.

        Args:
            x: Input tensor with ``in_ch`` channels.
            feat: Indexable collection of conditioning tensors; entries 0-4
                are read. Entry ``i`` is added to the output of encoder
                level ``i`` and entries 0-3 are reused by the decoder stage
                at the same level, so each must be broadcast-compatible
                with the features at its level.
            cond: Accepted for interface compatibility; not used here.

        Returns:
            Output of the final 1x1 convolution, with ``out_ch`` channels.

        Note:
            Each encoder halves the resolution and each decoder doubles it;
            inputs whose height and width are divisible by 16 keep the
            encoder and decoder resolutions aligned.
        """
        # Encoder: keep every level's output for the skip connections.
        enc0 = self.in_conv(x, feat[0])
        enc1 = self.down1(enc0, feat[1])
        enc2 = self.down2(enc1, feat[2])
        enc3 = self.down3(enc2, feat[3])
        mid = self.bottleneck(enc3, feat[4])

        # Decoder: each stage consumes the skip and condition of its level.
        out = self.up1(mid, enc3, feat[3])
        out = self.up2(out, enc2, feat[2])
        out = self.up3(out, enc1, feat[1])
        out = self.up4(out, enc0, feat[0])

        return self.final(out)