ginka-generator/ginka/model/unet.py

129 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
class GinkaAdaIN(nn.Module):
    """Adaptive Instance Normalization (AdaIN).

    Instance-normalizes a feature map, then applies a per-channel affine
    transform whose scale (gamma) and shift (beta) are predicted from a
    condition vector by a single linear layer.
    """

    def __init__(self, num_features, condition_dim):
        """
        Args:
            num_features: number of channels being normalized (C of the input map).
            condition_dim: feature dimension of the condition input.
        """
        super(GinkaAdaIN, self).__init__()
        # A single projection yields both gamma and beta (2 * num_features outputs).
        self.fc = nn.Linear(condition_dim, num_features * 2)

    def forward(self, x, condition):
        """
        Args:
            x: [B, C, H, W] input feature map, with C == num_features.
            condition: [N, condition_dim] condition vector to inject. N must be
                B, or 1 to share one condition across the whole batch.

        Returns:
            [B, C, H, W] tensor equal to ``gamma * instance_norm(x) + beta``.

        Raises:
            ValueError: if the predicted channel count does not match x's channels.
        """
        gamma, beta = self.fc(condition).chunk(2, dim=1)  # split into gamma and beta
        if gamma.shape[1] != x.shape[1]:
            # The old view()-based reshape only checked the total element count,
            # so some mismatches were silently accepted with scrambled parameters.
            raise ValueError(
                f"AdaIN expects {gamma.shape[1]} channels, got input with {x.shape[1]}"
            )
        # [N, C] -> [N, C, 1, 1]: per-channel parameters broadcast over H and W,
        # and over the batch when N == 1.
        gamma = gamma[:, :, None, None]
        beta = beta[:, :, None, None]
        x = F.instance_norm(x)  # normalize each sample/channel over its spatial dims
        return gamma * x + beta
class ConvBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm2d and ReLU.

    Padding of 1 keeps H and W unchanged; channels go in_ch -> out_ch -> out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.conv(x)
        return features
class AdaINConvBlock(nn.Module):
    """A ConvBlock whose output is modulated by GinkaAdaIN (no resolution change)."""

    def __init__(self, in_ch, out_ch, feat_dim):
        super().__init__()
        # Registration order (conv, then adain) is kept as-is.
        self.conv = ConvBlock(in_ch, out_ch)
        self.adain = GinkaAdaIN(out_ch, feat_dim)

    def forward(self, x, feat):
        # feat is passed to GinkaAdaIN as its condition vector.
        return self.adain(self.conv(x), feat)
class GinkaEncoder(nn.Module):
    """Encoder (downsampling) stage.

    Applies a ConvBlock (in_ch -> out_ch), halves H and W with 2x2 max pooling,
    then modulates the pooled map with AdaIN driven by the condition vector.
    """

    def __init__(self, in_ch, out_ch, feat_dim):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.pool = nn.MaxPool2d(2)
        self.adain = GinkaAdaIN(out_ch, feat_dim)

    def forward(self, x, feat):
        convolved = self.conv(x)
        downsampled = self.pool(convolved)
        return self.adain(downsampled, feat)
class GinkaUpSample(nn.Module):
    """Learned 2x upsampling: ConvTranspose2d(k=2, s=2) -> BatchNorm2d -> GELU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        transpose = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = nn.Sequential(transpose, nn.BatchNorm2d(out_ch), nn.GELU())

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled
class GinkaDecoder(nn.Module):
    """Decoder (upsampling) stage.

    Upsamples the incoming map 2x (to in_ch // 2 channels), concatenates it with
    the encoder skip connection, fuses both with a ConvBlock and modulates the
    result with AdaIN driven by the condition vector.
    """

    def __init__(self, in_ch, out_ch, feat_dim):
        """
        Args:
            in_ch: channels of the incoming feature map. The ConvBlock consumes
                in_ch channels after concatenation, so ``skip`` must carry
                in_ch - in_ch // 2 channels.
            out_ch: channels produced by this stage.
            feat_dim: dimension of the condition vector fed to AdaIN.
        """
        super().__init__()
        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
        self.conv = ConvBlock(in_ch, out_ch)
        self.adain = GinkaAdaIN(out_ch, feat_dim)

    def forward(self, x, skip, feat):
        x = self.upsample(x)
        # MaxPool2d floors odd sizes on the way down, so the exact-2x transposed
        # conv can come out one pixel smaller than the skip tensor and torch.cat
        # would fail. Pad (or, for negative diffs, crop) the upsampled map to the
        # skip's H and W. When sizes already agree this branch is skipped.
        diff_h = skip.shape[-2] - x.shape[-2]
        diff_w = skip.shape[-1] - x.shape[-1]
        if diff_h or diff_w:
            x = F.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        x = self.adain(x, feat)
        return x
class GinkaUNet(nn.Module):
    def __init__(self, in_ch=1, base_ch=64, out_ch=32, feat_dim=1024):
        """UNet part of the Ginka model.

        A full-resolution AdaIN conv stem, three GinkaEncoder stages plus a
        GinkaEncoder bottleneck (each halving H and W via max pooling), four
        GinkaDecoder stages that upsample and merge the matching skip maps,
        and a final 1x1 convolution to out_ch channels. Every stage receives
        the same condition vector ``feat`` (feat_dim features) for AdaIN.

        Args:
            in_ch: channels of the input map.
            base_ch: channel width of the first stage; doubled at each level.
            out_ch: channels of the output map.
            feat_dim: dimension of the condition vector.
        """
        super().__init__()
        # Channel width per level: base, 2x, 4x, 8x, 16x.
        width = [base_ch * (2 ** level) for level in range(5)]
        # Modules are created in the same order as before so that seeded
        # initialization and state_dict ordering are unchanged.
        self.in_conv = AdaINConvBlock(in_ch, width[0], feat_dim)
        self.down1 = GinkaEncoder(width[0], width[1], feat_dim)
        self.down2 = GinkaEncoder(width[1], width[2], feat_dim)
        self.down3 = GinkaEncoder(width[2], width[3], feat_dim)
        self.bottleneck = GinkaEncoder(width[3], width[4], feat_dim)
        self.up1 = GinkaDecoder(width[4], width[3], feat_dim)
        self.up2 = GinkaDecoder(width[3], width[2], feat_dim)
        self.up3 = GinkaDecoder(width[2], width[1], feat_dim)
        self.up4 = GinkaDecoder(width[1], width[0], feat_dim)
        # Kept as a Sequential so the state_dict key stays "final.0.*".
        self.final = nn.Sequential(
            nn.Conv2d(width[0], out_ch, 1),
        )

    def forward(self, x, feat):
        # Collect encoder outputs (full res, /2, /4, /8) for the skip connections.
        skips = [self.in_conv(x, feat)]
        for encoder in (self.down1, self.down2, self.down3):
            skips.append(encoder(skips[-1], feat))
        hidden = self.bottleneck(skips[-1], feat)
        # Decode from the deepest skip back up to full resolution.
        decoders = (self.up1, self.up2, self.up3, self.up4)
        for decoder, skip in zip(decoders, reversed(skips)):
            hidden = decoder(hidden, skip, feat)
        return self.final(hidden)