mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-18 07:31:11 +08:00
129 lines
4.2 KiB
Python
129 lines
4.2 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
class GinkaAdaIN(nn.Module):
    """Adaptive instance normalization (AdaIN) driven by a condition vector.

    A single linear layer maps the condition vector to a per-channel scale
    (gamma) and shift (beta); both are applied to the instance-normalized
    feature map.
    """

    def __init__(self, num_features, condition_dim):
        """Create the projection that produces gamma and beta.

        Args:
            num_features: Number of channels being normalized.
            condition_dim: Feature dimension of the condition input.
        """
        super().__init__()
        # One projection yields both halves: gamma first, beta second.
        self.fc = nn.Linear(condition_dim, 2 * num_features)

    def forward(self, x, condition):
        """Normalize ``x`` per instance, then scale and shift it.

        Args:
            x: Input feature map, ``[B, C, H, W]``.
            condition: Condition vector to inject, ``[B, condition_dim]``.

        Returns:
            ``gamma * instance_norm(x) + beta``, same shape as ``x``.
        """
        batch, channels = x.shape[0], x.shape[1]

        # Split the projection into the scale and shift halves.
        scale, shift = torch.chunk(self.fc(condition), 2, dim=1)

        # Reshape to [B, C, 1, 1] so they broadcast over the spatial dims.
        scale = scale.view(batch, channels, 1, 1)
        shift = shift.view(batch, channels, 1, 1)

        normalized = F.instance_norm(x)
        return scale * normalized + shift
|
||
|
||
class ConvBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Both convolutions use ``padding=1``, so the spatial size is unchanged;
    only the channel count goes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        """Build the conv stack.

        Args:
            in_ch: Channel count of the input feature map.
            out_ch: Channel count produced by both convolutions.
        """
        super().__init__()
        stages = []
        channels = in_ch
        for _ in range(2):
            stages.append(nn.Conv2d(channels, out_ch, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU())
            # The second convolution consumes the first one's output.
            channels = out_ch
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv stack and return the result."""
        out = self.conv(x)
        return out
|
||
|
||
class AdaINConvBlock(nn.Module):
    """A ``ConvBlock`` followed by condition-driven ``GinkaAdaIN``."""

    def __init__(self, in_ch, out_ch, feat_dim):
        """Build the conv block and its AdaIN layer.

        Args:
            in_ch: Channel count of the input feature map.
            out_ch: Channel count after the conv block (and normalized by AdaIN).
            feat_dim: Dimension of the condition vector fed to AdaIN.
        """
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.adain = GinkaAdaIN(out_ch, feat_dim)

    def forward(self, x, feat):
        """Convolve ``x``, then modulate the result with ``feat`` via AdaIN."""
        return self.adain(self.conv(x), feat)
|
||
|
||
class GinkaEncoder(nn.Module):
    """Encoder (downsampling) stage: conv block, 2x max-pool, then AdaIN."""

    def __init__(self, in_ch, out_ch, feat_dim):
        """Build the encoder stage.

        Args:
            in_ch: Channel count of the input feature map.
            out_ch: Channel count produced by the conv block.
            feat_dim: Dimension of the condition vector fed to AdaIN.
        """
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.pool = nn.MaxPool2d(2)
        self.adain = GinkaAdaIN(out_ch, feat_dim)

    def forward(self, x, feat):
        """Convolve and downsample ``x``, then modulate it with ``feat``."""
        downsampled = self.pool(self.conv(x))
        # AdaIN is applied after pooling, i.e. on the reduced feature map.
        return self.adain(downsampled, feat)
|
||
|
||
class GinkaUpSample(nn.Module):
    """2x upsampling: stride-2 transposed convolution, BatchNorm, then GELU."""

    def __init__(self, in_ch, out_ch):
        """Build the upsampling stack.

        Args:
            in_ch: Channel count of the input feature map.
            out_ch: Channel count after the transposed convolution.
        """
        super().__init__()
        # Kernel 2 with stride 2 doubles both spatial dimensions.
        deconv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        norm = nn.BatchNorm2d(out_ch)
        activation = nn.GELU()
        self.conv = nn.Sequential(deconv, norm, activation)

    def forward(self, x):
        """Return ``x`` upsampled by a factor of two."""
        upsampled = self.conv(x)
        return upsampled
|
||
|
||
class GinkaDecoder(nn.Module):
    """Decoder (upsampling) stage.

    Upsamples the incoming feature map by 2x, concatenates it with the skip
    connection along the channel axis, runs a conv block and finally applies
    AdaIN with the condition vector.
    """

    def __init__(self, in_ch, out_ch, feat_dim):
        """Build the decoder stage.

        Args:
            in_ch: Channel count of the incoming feature map. The upsampler
                reduces it to ``in_ch // 2``; the conv block expects
                ``in_ch`` channels after concatenation, so the skip tensor
                must supply the remaining ``in_ch - in_ch // 2`` channels.
            out_ch: Channel count produced by the conv block.
            feat_dim: Dimension of the condition vector fed to AdaIN.
        """
        super().__init__()
        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
        self.conv = ConvBlock(in_ch, out_ch)
        self.adain = GinkaAdaIN(out_ch, feat_dim)

    def forward(self, x, skip, feat):
        """Upsample ``x``, merge it with ``skip`` and modulate with ``feat``.

        Args:
            x: Feature map from the previous (deeper) stage.
            skip: Skip-connection feature map from the matching encoder stage.
            feat: Condition vector passed to AdaIN.

        Returns:
            The decoded feature map with ``out_ch`` channels and the spatial
            size of ``skip``.
        """
        x = self.upsample(x)

        # A 2x pooling of an odd-sized map drops a row/column, so the 2x
        # upsampled tensor can be smaller than the skip tensor. Align the
        # spatial size before concatenating; when the sizes already match
        # this is skipped entirely. Negative amounts make F.pad crop.
        diff_h = skip.shape[-2] - x.shape[-2]
        diff_w = skip.shape[-1] - x.shape[-1]
        if diff_h != 0 or diff_w != 0:
            left = diff_w // 2
            top = diff_h // 2
            x = F.pad(x, [left, diff_w - left, top, diff_h - top])

        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        x = self.adain(x, feat)
        return x
|
||
|
||
class GinkaUNet(nn.Module):
    """UNet part of the Ginka model.

    An AdaIN-conditioned UNet: one input block, four downsampling stages
    (the last one acting as the bottleneck), four upsampling stages with
    skip connections, and a final 1x1 convolution.
    """

    def __init__(self, in_ch=1, base_ch=64, out_ch=32, feat_dim=1024):
        """Build the UNet.

        Args:
            in_ch: Channel count of the input.
            base_ch: Channel count after the input block; deeper stages use
                2x, 4x, 8x and 16x this width.
            out_ch: Channel count of the output.
            feat_dim: Dimension of the condition vector given to every
                AdaIN layer.
        """
        super().__init__()
        c1, c2, c4, c8, c16 = (base_ch * mult for mult in (1, 2, 4, 8, 16))

        # Encoder path.
        self.in_conv = AdaINConvBlock(in_ch, c1, feat_dim)
        self.down1 = GinkaEncoder(c1, c2, feat_dim)
        self.down2 = GinkaEncoder(c2, c4, feat_dim)
        self.down3 = GinkaEncoder(c4, c8, feat_dim)

        # The bottleneck is itself an encoder stage, so it downsamples too.
        self.bottleneck = GinkaEncoder(c8, c16, feat_dim)

        # Decoder path, mirroring the encoder widths.
        self.up1 = GinkaDecoder(c16, c8, feat_dim)
        self.up2 = GinkaDecoder(c8, c4, feat_dim)
        self.up3 = GinkaDecoder(c4, c2, feat_dim)
        self.up4 = GinkaDecoder(c2, c1, feat_dim)

        # 1x1 convolution projecting to the requested output channels.
        self.final = nn.Sequential(nn.Conv2d(c1, out_ch, 1))

    def forward(self, x, feat):
        """Run the conditioned UNet.

        Args:
            x: Input feature map.
            feat: Condition vector passed to every stage.

        Returns:
            The output of the final 1x1 convolution (``out_ch`` channels).
        """
        # Encoder: remember every intermediate map for the skip connections.
        skips = [self.in_conv(x, feat)]
        for encoder in (self.down1, self.down2, self.down3):
            skips.append(encoder(skips[-1], feat))

        out = self.bottleneck(skips[-1], feat)

        # Decoder: consume the skips from deepest to shallowest.
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            out = decoder(out, skips.pop(), feat)

        return self.final(out)
|