Mirror of https://github.com/unanmed/ginka-generator.git
perf: model fine-tuning
This commit is contained in:
parent
21b693ec21
commit
fa48863946
@@ -19,11 +19,11 @@ class DoubleConvBlock(nn.Module):
         self.cnn = nn.Sequential(
             nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate'),
             nn.InstanceNorm2d(feats[1]),
-            nn.ELU(),
+            nn.GELU(),
 
             nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate'),
             nn.InstanceNorm2d(feats[2]),
-            nn.ELU(),
+            nn.GELU(),
         )
 
     def forward(self, x):
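Note (reviewer sketch, not part of the diff): the recurring change in this commit is swapping nn.ELU for nn.GELU. A standalone comparison of the two activations:

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-3, 3, 7)
    # ELU is linear for x > 0 and saturates at -1 for very negative x;
    # GELU is x * Phi(x), a smooth gate that keeps a small negative tail.
    print(F.elu(x))
    print(F.gelu(x))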
@@ -57,11 +57,11 @@ class GCNBlock(nn.Module):
 
         # GCN forward
         x = self.conv1(x, edge_index)
-        x = F.elu(self.norm1(x))
+        x = F.gelu(self.norm1(x))
         x = self.conv2(x, edge_index)
-        x = F.elu(self.norm2(x))
+        x = F.gelu(self.norm2(x))
         x = self.conv3(x, edge_index)
-        x = F.elu(self.norm3(x))
+        x = F.gelu(self.norm3(x))
 
         # Reshape back to [B, C, H, W]
         x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
@@ -92,9 +92,9 @@ class TransformerGCNBlock(nn.Module):
 
         # GCN forward
         x = self.conv1(x, edge_index)
-        x = F.elu(self.norm1(x))
+        x = F.gelu(self.norm1(x))
        x = self.conv2(x, edge_index)
-        x = F.elu(self.norm2(x))
+        x = F.gelu(self.norm2(x))
 
         # Reshape back to [B, C, H, W]
         x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
@@ -104,8 +104,8 @@ class ConvFusionModule(nn.Module):
     def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
         super().__init__()
         self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch])
-        self.gcn = GCNBlock(in_ch, hidden_ch, in_ch, w, h)
-        self.fusion = DoubleConvBlock([in_ch*2, hidden_ch*2, out_ch])
+        self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h)
+        self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch])
 
     def forward(self, x):
         x1 = self.cnn(x)
@@ -120,11 +120,11 @@ class DoubleFCModule(nn.Module):
         self.fc = nn.Sequential(
             nn.Linear(in_dim, hidden_dim),
             nn.LayerNorm(hidden_dim),
-            nn.ELU(),
+            nn.GELU(),
 
             nn.Linear(hidden_dim, out_dim),
             nn.LayerNorm(out_dim),
-            nn.ELU()
+            nn.GELU()
         )
 
     def forward(self, x):
@@ -6,22 +6,22 @@ from .common import DoubleFCModule
 class ConditionEncoder(nn.Module):
     def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
         super().__init__()
-        self.tag_embed = DoubleFCModule(tag_dim, hidden_dim*2, hidden_dim)
-        self.val_embed = DoubleFCModule(val_dim, hidden_dim*2, hidden_dim)
-        self.stage_embed = DoubleFCModule(1, hidden_dim*2, hidden_dim)
+        self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim)
+        self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim)
+        self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim)
         self.encoder = nn.TransformerEncoder(
             nn.TransformerEncoderLayer(
                 d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
                 batch_first=True
             ),
-            num_layers=6
+            num_layers=4
         )
         self.fusion = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim*2),
-            nn.LayerNorm(hidden_dim*2),
-            nn.ELU(),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+            nn.GELU(),
 
-            nn.Linear(hidden_dim*2, out_dim)
+            nn.Linear(hidden_dim, out_dim)
         )
 
     def forward(self, tag, val, stage):
@@ -38,18 +38,10 @@ class ConditionInjector(nn.Module):
     def __init__(self, cond_dim, out_dim):
         super().__init__()
         self.gamma_layer = nn.Sequential(
-            nn.Linear(cond_dim, cond_dim*2),
-            nn.LayerNorm(cond_dim*2),
-            nn.ELU(),
-
-            nn.Linear(cond_dim*2, out_dim)
+            nn.Linear(cond_dim, out_dim)
         )
         self.beta_layer = nn.Sequential(
-            nn.Linear(cond_dim, cond_dim*2),
-            nn.LayerNorm(cond_dim*2),
-            nn.ELU(),
-
-            nn.Linear(cond_dim*2, out_dim)
+            nn.Linear(cond_dim, out_dim)
         )
 
     def forward(self, x, cond):
 
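Note (reviewer sketch, not part of the diff): the slimmed-down ConditionInjector is FiLM-style conditioning; a single linear layer now maps the condition vector straight to per-channel gamma/beta. What forward computes, with made-up sizes:

    import torch
    import torch.nn as nn

    cond_dim, out_ch = 256, 128
    gamma_layer = nn.Linear(cond_dim, out_ch)
    beta_layer = nn.Linear(cond_dim, out_ch)

    x = torch.randn(2, out_ch, 13, 13)                   # feature map [B, C, H, W]
    cond = torch.randn(2, cond_dim)                      # condition vector [B, cond_dim]
    gamma = gamma_layer(cond).unsqueeze(2).unsqueeze(3)  # [B, C, 1, 1]
    beta = beta_layer(cond).unsqueeze(2).unsqueeze(3)
    y = x * gamma + beta                                 # per-channel affine modulation
    print(y.shape)                                       # torch.Size([2, 128, 13, 13])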
@@ -2,22 +2,138 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.utils import spectral_norm
-from torch_geometric.nn import global_max_pool, GCNConv
+from torch_geometric.nn import global_max_pool, GCNConv, TransformerConv
+from torch_geometric.utils import grid
 from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
 from shared.graph import batch_convert_soft_map_to_graph
 from .vision import MinamoVisionModel
 from .topo import MinamoTopoModel
+from ..common.cond import ConditionEncoder
 
 def print_memory(tag=""):
     print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
 
+def batch_edge_index(B, edge_index, num_nodes_per_batch):
+    # Offset edge_index for each sample in the batch
+    edge_index = edge_index.clone() # [2, E]
+    batch_edge_index = []
+    for i in range(B):
+        offset = i * num_nodes_per_batch
+        batch_edge_index.append(edge_index + offset)
+    return torch.cat(batch_edge_index, dim=1)
+
+class DoubleConvBlock(nn.Module):
+    def __init__(self, feats: tuple[int, int, int]):
+        super().__init__()
+        self.cnn = nn.Sequential(
+            spectral_norm(nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate')),
+            nn.GELU(),
+
+            spectral_norm(nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate')),
+            nn.GELU(),
+        )
+
+    def forward(self, x):
+        x = self.cnn(x)
+        return x
+
+class TransformerGCNBlock(nn.Module):
+    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
+        super().__init__()
+        self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True)
+        self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1)
+        self.single_edge_index, _ = grid(h, w) # [2, E] for a single map
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
+        device = x.device
+        edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W)
+        x = self.conv1(x, edge_index)
+        x = F.gelu(x)
+        x = self.conv2(x, edge_index)
+        x = F.gelu(x)
+        x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
+        return x
+
+class ConvFusionModule(nn.Module):
+    def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
+        super().__init__()
+        self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch])
+        self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h)
+        self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch])
+
+    def forward(self, x):
+        x1 = self.cnn(x)
+        x2 = self.gcn(x)
+        x = torch.cat([x1, x2], dim=1)
+        x = self.fusion(x)
+        return x
+
+class DoubleFCModule(nn.Module):
+    def __init__(self, in_dim, hidden_dim, out_dim):
+        super().__init__()
+        self.fc = nn.Sequential(
+            spectral_norm(nn.Linear(in_dim, hidden_dim)),
+            nn.GELU(),
+
+            spectral_norm(nn.Linear(hidden_dim, out_dim)),
+            nn.GELU()
+        )
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+class ConditionEncoder(nn.Module):
+    def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
+        super().__init__()
+        self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim)
+        self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim)
+        self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim)
+        self.encoder = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(
+                d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
+                batch_first=True
+            ),
+            num_layers=4
+        )
+        self.fusion = nn.Sequential(
+            spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
+            nn.GELU(),
+
+            spectral_norm(nn.Linear(hidden_dim, out_dim))
+        )
+
+    def forward(self, tag, val, stage):
+        tag = self.tag_embed(tag)
+        val = self.val_embed(val)
+        stage = self.stage_embed(stage)
+        feat = torch.stack([tag, val, stage], dim=1)
+        feat = self.encoder(feat)
+        feat = torch.mean(feat, dim=1)
+        feat = self.fusion(feat)
+        return feat
+
+class ConditionInjector(nn.Module):
+    def __init__(self, cond_dim, out_dim):
+        super().__init__()
+        self.gamma_layer = nn.Sequential(
+            spectral_norm(nn.Linear(cond_dim, out_dim))
+        )
+        self.beta_layer = nn.Sequential(
+            spectral_norm(nn.Linear(cond_dim, out_dim))
+        )
+
+    def forward(self, x, cond):
+        gamma = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3)
+        beta = self.beta_layer(cond).unsqueeze(2).unsqueeze(3)
+        return x * gamma + beta
+
 class CNNHead(nn.Module):
     def __init__(self, in_ch):
         super().__init__()
         self.cnn = nn.Sequential(
             spectral_norm(nn.Conv2d(in_ch, in_ch, 3)),
-            nn.LeakyReLU(0.2),
+            nn.GELU(),
 
             nn.AdaptiveMaxPool2d((2, 2))
         )
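Note (reviewer sketch, not part of the diff): batch_edge_index turns a batch of H x W maps into one disconnected graph by offsetting node indices per sample, so a single TransformerConv call covers the whole batch. A small sanity check:

    import torch
    from torch_geometric.utils import grid

    edge_index, _ = grid(3, 3)  # edge_index of a 3x3 grid graph, nodes 0..8
    B, n = 2, 9
    batched = torch.cat([edge_index + i * n for i in range(B)], dim=1)
    # Sample 0 owns nodes 0..8, sample 1 owns nodes 9..17; no cross-sample edges.
    print(edge_index.shape, batched.shape)  # [2, E] -> [2, B*E]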
@@ -46,7 +162,7 @@ class GCNHead(nn.Module):
 
     def forward(self, x, graph, cond):
         x = self.gcn(x, graph.edge_index)
-        x = F.leaky_relu(x, 0.2)
+        x = F.gelu(x)
         x = global_max_pool(x, graph.batch)
         cond = self.proj(cond)
         proj = torch.sum(x * cond, dim=1, keepdim=True)
@@ -91,6 +207,65 @@ class MinamoModel(nn.Module):
             raise RuntimeError("Unknown critic stage.")
         score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
         return score, vision_score, topo_score
+
+class MinamoHead2(nn.Module):
+    def __init__(self, in_ch, hidden_ch):
+        super().__init__()
+        self.conv = ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13)
+        self.pool = nn.AdaptiveMaxPool2d(1)
+        self.proj = spectral_norm(nn.Linear(256, hidden_ch))
+        self.fc = spectral_norm(nn.Linear(hidden_ch, 1))
+
+    def forward(self, x, cond):
+        x = self.conv(x)
+        x = self.pool(x)
+        x = x.squeeze(3).squeeze(2)
+        cond = self.proj(cond)
+        proj = torch.sum(x * cond, dim=1, keepdim=True)
+        x = self.fc(x) + proj
+        return x
+
+class MinamoModel2(nn.Module):
+    def __init__(self, tile_types=32):
+        super().__init__()
+        self.cond = ConditionEncoder(64, 16, 256, 256)
+
+        self.conv1 = ConvFusionModule(tile_types, 256, 128, 13, 13)
+        self.conv2 = ConvFusionModule(128, 256, 256, 13, 13)
+        self.conv3 = ConvFusionModule(256, 512, 256, 13, 13)
+
+        self.head0 = MinamoHead2(256, 256) # critic head for the random input head
+        self.head1 = MinamoHead2(256, 256)
+        self.head2 = MinamoHead2(256, 256)
+        self.head3 = MinamoHead2(256, 256)
+
+        self.inject1 = ConditionInjector(256, 128)
+        self.inject2 = ConditionInjector(256, 256)
+        self.inject3 = ConditionInjector(256, 256)
+
+    def forward(self, x, stage, tag_cond, val_cond):
+        B, D = tag_cond.shape
+        stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
+        cond = self.cond(tag_cond, val_cond, stage_tensor)
+        x = self.conv1(x)
+        x = self.inject1(x, cond)
+        x = self.conv2(x)
+        x = self.inject2(x, cond)
+        x = self.conv3(x)
+        x = self.inject3(x, cond)
+
+        if stage == 0:
+            score = self.head0(x, cond)
+        elif stage == 1:
+            score = self.head1(x, cond)
+        elif stage == 2:
+            score = self.head2(x, cond)
+        elif stage == 3:
+            score = self.head3(x, cond)
+        else:
+            raise RuntimeError("Unknown critic stage.")
+
+        return score
+
 # Check VRAM usage
 if __name__ == "__main__":
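Note (reviewer sketch, not part of the diff): MinamoHead2 scores a map the way a projection discriminator does: an unconditional term fc(x) plus an inner product between the pooled features and a projected condition. Reduced to its core:

    import torch
    import torch.nn as nn

    hidden_ch, cond_dim = 256, 256
    proj = nn.Linear(cond_dim, hidden_ch)
    fc = nn.Linear(hidden_ch, 1)

    feat = torch.randn(4, hidden_ch)  # pooled map features
    cond = torch.randn(4, cond_dim)   # encoded condition
    score = fc(feat) + torch.sum(feat * proj(cond), dim=1, keepdim=True)
    print(score.shape)                # torch.Size([4, 1]), one critic score per sample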
@ -99,19 +274,19 @@ if __name__ == "__main__":
|
||||
val = torch.rand(1, 16).cuda()
|
||||
|
||||
# 初始化模型
|
||||
model = MinamoModel().cuda()
|
||||
model = MinamoModel2().cuda()
|
||||
|
||||
print_memory("初始化后")
|
||||
|
||||
# 前向传播
|
||||
output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1, tag, val)
|
||||
output = model(input, 1, tag, val)
|
||||
|
||||
print_memory("前向传播后")
|
||||
|
||||
print(f"输入形状: feat={input.shape}")
|
||||
print(f"输出形状: output={output.shape}")
|
||||
# print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
|
||||
# print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
|
||||
print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
|
||||
print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
|
||||
print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
|
||||
print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}")
|
||||
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
@@ -2,12 +2,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.utils import spectral_norm
-from torch_geometric.nn import GATConv
+from torch_geometric.nn import GATConv, TransformerConv
 from torch_geometric.data import Data
 
 class MinamoTopoModel(nn.Module):
     def __init__(
-        self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512
+        self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512
     ):
         super().__init__()
         # Takes softmax probabilities and projects them directly
@@ -16,9 +16,9 @@ class MinamoTopoModel(nn.Module):
             nn.LeakyReLU(0.2)
         )
         # Graph convolution layers
-        self.conv1 = GATConv(emb_dim, hidden_dim, heads=8)
-        self.conv2 = GATConv(hidden_dim*8, hidden_dim, heads=8)
-        self.conv3 = GATConv(hidden_dim*8, out_dim, heads=1)
+        self.conv1 = TransformerConv(emb_dim, hidden_dim, heads=8)
+        self.conv2 = TransformerConv(hidden_dim*8, hidden_dim, heads=8)
+        self.conv3 = TransformerConv(hidden_dim*8, out_dim, heads=1)
 
     def forward(self, graph: Data):
         x = self.input_proj(graph.x)
@@ -10,13 +10,10 @@ class MinamoVisionModel(nn.Module):
             spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11
             nn.LeakyReLU(0.2),
 
-            spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), # 9*9
+            spectral_norm(nn.Conv2d(in_ch*2, in_ch*8, 3)), # 9*9
             nn.LeakyReLU(0.2),
 
-            spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7
-            nn.LeakyReLU(0.2),
-
-            spectral_norm(nn.Conv2d(in_ch*8, out_ch, 3)), # 5*5
+            spectral_norm(nn.Conv2d(in_ch*8, out_ch, 3)), # 7*7
             nn.LeakyReLU(0.2),
         )
@@ -142,13 +142,14 @@ class GinkaWGANDataset(Dataset):
         removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
         removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
         removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
+        _, masked = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, 0.5)
         rand = torch.rand(32, 32, 32, device=target.device)
 
         return {
             "real1": removed1,
             "masked1": rand,
             "real2": removed2,
-            "masked2": torch.zeros_like(target),
+            "masked2": masked,
             "real3": removed3,
             "masked3": torch.zeros_like(target),
             "tag_cond": tag_cond,
@@ -2,24 +2,25 @@ import torch
 import torch.nn as nn
 from ..common.common import ConvFusionModule
 from ..common.cond import ConditionInjector
+from .unet import GinkaEncoderPath, GinkaDecoderPath
 
 class RandomInputHead(nn.Module):
     def __init__(self):
         super().__init__()
-        self.enc = ConvFusionModule(32, 256, 256, 32, 32)
+        self.enc = GinkaEncoderPath(32, 32)
+        self.dec = GinkaDecoderPath(32)
         self.out_conv = nn.Sequential(
-            nn.Conv2d(256, 128, 3, padding=1, padding_mode='replicate'),
-            nn.InstanceNorm2d(128),
-            nn.ELU(),
-            nn.AdaptiveMaxPool2d((15, 15)),
+            nn.Conv2d(32, 64, 3, padding=0),
+            nn.InstanceNorm2d(64),
+            nn.GELU(),
 
             nn.AdaptiveMaxPool2d((13, 13)),
-            nn.Conv2d(128, 32, 1),
+            nn.Conv2d(64, 32, 1),
         )
         self.inject = ConditionInjector(256, 256)
 
     def forward(self, x, cond):
-        x = self.enc(x)
-        x = self.inject(x, cond)
+        x1, x2, x3, x4 = self.enc(x, cond)
+        x = self.dec(x1, x2, x3, x4, cond)
         x = self.out_conv(x)
         return x
@@ -28,15 +29,12 @@ class InputUpsample(nn.Module):
         super().__init__()
         self.net = nn.Sequential(
             ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13),
-            nn.ELU(),
 
             nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26
             ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, 26, 26),
-            nn.ELU(),
 
             nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32
             ConvFusionModule(hidden_ch, hidden_ch, out_ch, 32, 32),
-            nn.ELU(),
         )
 
     def forward(self, x): # [B, C, 13, 13]
@@ -47,18 +45,14 @@ class GinkaInput(nn.Module):
     def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
         super().__init__()
         self.out_size = out_size
-        self.enc1 = ConvFusionModule(in_ch, in_ch*4, in_ch, in_size[0], in_size[1])
         self.upsample = InputUpsample(in_ch, in_ch*2, out_ch)
-        self.enc2 = ConvFusionModule(out_ch, out_ch*4, out_ch, out_size[0], out_size[1])
-        self.inject1 = ConditionInjector(256, in_ch)
+        self.enc = ConvFusionModule(out_ch, out_ch*2, out_ch, out_size[0], out_size[1])
+        self.inject1 = ConditionInjector(256, out_ch)
         self.inject2 = ConditionInjector(256, out_ch)
-        self.inject3 = ConditionInjector(256, out_ch)
 
     def forward(self, x, cond):
-        x = self.enc1(x)
-        x = self.inject1(x, cond)
         x = self.upsample(x)
+        x = self.inject1(x, cond)
+        x = self.enc(x)
         x = self.inject2(x, cond)
-        x = self.enc2(x)
-        x = self.inject3(x, cond)
         return x
@@ -1,12 +1,7 @@
 import math
-from tqdm import tqdm
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch_geometric.data import Data
-from shared.graph import batch_convert_soft_map_to_graph
-from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
-from ..critic.model import MinamoModel
 
 CLASS_NUM = 32
 ILLEGAL_MAX_NUM = 30
@@ -156,15 +151,15 @@ def entrance_constraint_loss(
     )
     return total_loss
 
-def input_head_illegal_loss(input_map, allowed_classes=(0, 1)):
-    C = input_map.shape[1]
-    mask = torch.ones(C, device=input_map.device)
-    mask[list(allowed_classes)] = 0 # mask out allowed classes; others stay 1
-    illegal_class_penalty = (input_map * mask.view(1, -1, 1, 1)).sum() / input_map.numel()
-
-    return illegal_class_penalty
+def input_head_illegal_loss(input_map, allowed_classes=[0, 1, 2]):
+    unallowed = get_not_allowed(allowed_classes, include_illegal=True)
+    illegal = input_map[:, unallowed, :, :]
+    penalty = torch.sum(illegal)
 
-def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1):
+    return penalty
 
+def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=[1, 2]):
     wall_prob = input_map[:, wall_class] # [B, H, W]
     wall_ratio = wall_prob.mean() # average wall-tile ratio
     wall_penalty = torch.clamp(wall_ratio - max_wall_ratio, min=0.0) # penalize only the excess
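Note (reviewer sketch, not part of the diff): with wall_class=[1, 2] the indexing selects two channels at once, and .mean() averages the probability mass over batch, channels, and space before thresholding. A tiny numeric check:

    import torch

    input_map = torch.softmax(torch.randn(2, 32, 13, 13), dim=1)  # soft tile probabilities
    wall_prob = input_map[:, [1, 2]]                              # [B, 2, H, W]
    wall_ratio = wall_prob.mean()                                 # scalar average wall occupancy
    penalty = torch.clamp(wall_ratio - 0.2, min=0.0)              # only penalize the excess
    print(wall_ratio.item(), penalty.item())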
@@ -241,6 +236,16 @@ def immutable_penalty_loss(
 
     return penalty
 
+def modifiable_penalty_loss(
+    probs: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
+) -> torch.Tensor:
+    target_modifiable = input[:, modifiable_classes, :, :]
+    pred_modifiable = probs[:, modifiable_classes, :, :]
+    existed = torch.clamp(target_modifiable - pred_modifiable, min=0.0, max=1.0)
+    penalty = F.mse_loss(existed, torch.zeros_like(existed, device=existed.device))
+
+    return penalty
+
 def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
     not_allowed = get_not_allowed(legal_classes, include_illegal=True)
     input_mask = pred[:, not_allowed, :, :]
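Note (reviewer sketch, not part of the diff): the new modifiable_penalty_loss only fires where a modifiable tile existed in the input but is missing from the prediction; clamp(target - pred, 0, 1) is nonzero exactly on those cells. Numeric example:

    import torch
    import torch.nn.functional as F

    target = torch.tensor([[1.0, 0.0], [1.0, 1.0]])  # tiles present in the input
    pred = torch.tensor([[1.0, 1.0], [0.0, 1.0]])    # generator probabilities
    existed = torch.clamp(target - pred, min=0.0, max=1.0)
    print(existed)                                   # only cell (1, 0) is penalized
    print(F.mse_loss(existed, torch.zeros_like(existed)))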
@@ -249,43 +254,40 @@ def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
     return penalty
 
 class WGANGinkaLoss:
-    def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.05, 0.5]):
-        # weight: critic loss, CE loss, immutable-class and illegal-tile loss, tile-type loss, entrance-existence loss, diversity loss, density loss
+    def __init__(self, lambda_gp=100, weight=[1, 0.4, 50, 0.2, 0.2, 0.05, 0.4]):
+        # weight:
+        # 1. critic loss plus tile-keeping loss (existing content in modifiable areas must be preserved)
+        # 2. CE loss
+        # 3. immutable-class loss and illegal-tile loss
+        # 4. tile-type loss
+        # 5. entrance-existence loss
+        # 6. diversity loss
+        # 7. density loss
         self.lambda_gp = lambda_gp # gradient penalty coefficient
         self.weight = weight
 
     def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond):
         # Interpolate
         batch_size = real_data.size(0)
-        epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device)
+        epsilon_data = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
         interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device)
-        interp_graph = batch_convert_soft_map_to_graph(interp_data).to(real_data.device)
 
         # Backprop through the critic and compute gradients w.r.t. the interpolated data
         interp_data.requires_grad_()
-        interp_graph.x.requires_grad_()
 
-        _, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage, tag_cond, val_cond)
+        d_score = critic(interp_data, stage, tag_cond, val_cond)
 
         # Compute the gradients
-        grad_vis = torch.autograd.grad(
-            outputs=d_vis_score, inputs=interp_data,
-            grad_outputs=torch.ones_like(d_vis_score),
-            create_graph=True, retain_graph=True, only_inputs=True
-        )[0]
-        grad_topo = torch.autograd.grad(
-            outputs=d_topo_score, inputs=interp_graph.x,
-            grad_outputs=torch.ones_like(d_topo_score),
+        grad = torch.autograd.grad(
+            outputs=d_score, inputs=interp_data,
+            grad_outputs=torch.ones_like(d_score),
             create_graph=True, retain_graph=True, only_inputs=True
         )[0]
 
         # L2 norm of the gradients
-        grad_norm_vis = grad_vis.view(batch_size, -1).norm(2, dim=1)
-        grad_norm_topo = grad_topo.view(batch_size, -1).norm(2, dim=1)
+        grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1)
         # Gradient penalty term
-        gp_loss_vis = ((grad_norm_vis - 1.0) ** 2).mean()
-        gp_loss_topo = ((grad_norm_topo - 1.0) ** 2).mean()
-        gp_loss = gp_loss_vis * VISION_WEIGHT + gp_loss_topo * TOPO_WEIGHT
+        gp_loss = ((grad_norm - 1.0) ** 2).mean()
-        # print(grad_norm_topo.mean().item(), grad_norm_vis.mean().item())
 
         return gp_loss
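Note (reviewer sketch, not part of the diff): the randn -> rand fix matters for WGAN-GP. The interpolation coefficient should be uniform on [0, 1] so the gradient norm is constrained on the segment between real and fake samples; a Gaussian epsilon can leave that segment entirely. The standard penalty, reduced to a single critic input:

    import torch

    def gradient_penalty(critic, real, fake):
        eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)  # uniform in [0, 1]
        interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
        score = critic(interp)
        grad = torch.autograd.grad(score.sum(), interp, create_graph=True)[0]
        return ((grad.reshape(real.size(0), -1).norm(2, dim=1) - 1.0) ** 2).mean()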
@@ -296,10 +298,8 @@ class WGANGinkaLoss:
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """ Critic loss function """
         fake_data = F.softmax(fake_data, dim=1)
-        real_graph = batch_convert_soft_map_to_graph(real_data)
-        fake_graph = batch_convert_soft_map_to_graph(fake_data)
-        real_scores, _, _ = critic(real_data, real_graph, stage, tag_cond, val_cond)
-        fake_scores, _, _ = critic(fake_data, fake_graph, stage, tag_cond, val_cond)
+        real_scores = critic(real_data, stage, tag_cond, val_cond)
+        fake_scores = critic(fake_data, stage, tag_cond, val_cond)
 
         # Wasserstein distance
         d_loss = fake_scores.mean() - real_scores.mean()
@@ -312,10 +312,9 @@ class WGANGinkaLoss:
     def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """ Generator loss function """
         probs_fake = F.softmax(fake, dim=1)
-        fake_graph = batch_convert_soft_map_to_graph(probs_fake)
-
-        fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
-        minamo_loss = -torch.mean(fake_scores)
+        fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
+        minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
         ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # the larger the mask, the smaller the CE weight
         immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
         constraint_loss = inner_constraint_loss(probs_fake)
@@ -343,9 +342,8 @@ class WGANGinkaLoss:
 
     def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor:
         probs_fake = F.softmax(fake, dim=1)
-        fake_graph = batch_convert_soft_map_to_graph(probs_fake)
-
-        fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
+        fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
         minamo_loss = -torch.mean(fake_scores)
         illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
         constraint_loss = inner_constraint_loss(probs_fake)
@@ -370,10 +368,9 @@ class WGANGinkaLoss:
 
     def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor:
         probs_fake = F.softmax(fake, dim=1)
-        fake_graph = batch_convert_soft_map_to_graph(probs_fake)
-
-        fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
-        minamo_loss = -torch.mean(fake_scores)
+        fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
+        minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
         immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
         constraint_loss = inner_constraint_loss(probs_fake)
         density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
@@ -395,13 +392,15 @@ class WGANGinkaLoss:
 
         return sum(losses)
 
-    def generator_input_head_loss(self, probs: torch.Tensor) -> torch.Tensor:
+    def generator_input_head_loss(self, critic, map: torch.Tensor, tag_cond, val_cond) -> torch.Tensor:
+        probs = F.softmax(map, dim=1)
+        head_scores = critic(probs, 0, tag_cond, val_cond)
         probs_a, probs_b = probs.chunk(2, dim=0)
 
         losses = [
+            torch.mean(head_scores),
             input_head_illegal_loss(probs),
             input_head_wall_loss(probs),
-            -js_divergence(probs_a, probs_b, softmax=False) * 0.3
+            -js_divergence(probs_a, probs_b, softmax=False) * 0.1
         ]
 
         return sum(losses)
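Note (reviewer sketch, not part of the diff): js_divergence is a repo helper; it enters the input-head loss with a negative sign, so more divergence between the two half-batches (more diversity) lowers the loss, now weighted 0.1 instead of 0.3. A generic JS divergence over per-pixel class distributions, assuming the class axis is dim=1 (my own version, not the repo's):

    import torch

    def js_divergence(p, q, eps=1e-8):
        m = 0.5 * (p + q)
        kl = lambda a, b: (a * ((a + eps) / (b + eps)).log()).sum(dim=1).mean()
        return 0.5 * kl(p, m) + 0.5 * kl(q, m)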
@@ -1,22 +1,15 @@
 import torch
 import torch.nn as nn
-from ..common.common import GCNBlock, DoubleConvBlock
+from ..common.common import ConvFusionModule
 from ..common.cond import ConditionInjector
 
 class StageHead(nn.Module):
     def __init__(self, in_ch, out_ch, out_size=(13, 13)):
         super().__init__()
-        self.cnn_head = DoubleConvBlock([in_ch, in_ch*2, in_ch])
-        self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
-        self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
+        self.dec = ConvFusionModule(in_ch, in_ch*2, in_ch, 32, 32)
         self.pool = nn.Sequential(
-            nn.Conv2d(in_ch, in_ch*2, 3, padding=1, padding_mode='replicate'),
-            nn.InstanceNorm2d(in_ch*2),
-            nn.ELU(),
-
-            nn.Conv2d(in_ch*2, in_ch, 3, padding=1, padding_mode='replicate'),
-            nn.InstanceNorm2d(in_ch),
-            nn.ELU(),
+            ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32),
+            ConvFusionModule(in_ch*2, in_ch*2, in_ch, 32, 32),
 
             nn.AdaptiveMaxPool2d(out_size),
             nn.Conv2d(in_ch, out_ch, 1)
@@ -24,10 +17,7 @@ class StageHead(nn.Module):
         self.inject = ConditionInjector(256, in_ch)
 
     def forward(self, x, cond):
-        x_cnn = self.cnn_head(x)
-        x_gcn = self.gcn_head(x)
-        x = torch.cat([x_cnn, x_gcn], dim=1)
-        x = self.fusion(x)
+        x = self.dec(x)
         x = self.inject(x, cond)
         x = self.pool(x)
         return x
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from shared.attention import ChannelAttention
-from ..common.common import GCNBlock, TransformerGCNBlock
+from ..common.common import GCNBlock, TransformerGCNBlock, DoubleConvBlock, ConvFusionModule
 from ..common.cond import ConditionInjector
 
 class GinkaTransformerEncoder(nn.Module):
@@ -37,16 +37,17 @@ class GinkaTransformerEncoder(nn.Module):
 class ConvBlock(nn.Module):
     def __init__(self, in_ch, out_ch, attn=True):
         super().__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
-            nn.InstanceNorm2d(out_ch),
-            nn.ELU(),
-            nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
-            nn.InstanceNorm2d(out_ch),
-        )
-        if attn:
-            self.conv.append(ChannelAttention(out_ch))
-        self.conv.append(nn.ELU())
+        self.conv = DoubleConvBlock([in_ch, out_ch, out_ch])
+        # self.conv = nn.Sequential(
+        #     nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
+        #     nn.InstanceNorm2d(out_ch),
+        #     nn.ELU(),
+        #     nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
+        #     nn.InstanceNorm2d(out_ch),
+        # )
+        # if attn:
+        #     self.conv.append(ChannelAttention(out_ch))
+        #     self.conv.append(nn.ELU())
 
     def forward(self, x):
         return self.conv(x)
@@ -64,47 +65,24 @@ class FusionModule(nn.Module):
 class GinkaUNetInput(nn.Module):
     def __init__(self, in_ch, out_ch, w, h):
         super().__init__()
-        self.conv = ConvBlock(in_ch, in_ch)
-        self.gcn = TransformerGCNBlock(in_ch, in_ch*2, in_ch, w, h)
-        self.fusion = ConvBlock(in_ch*2, out_ch)
-        self.inject = ConditionInjector(256, out_ch)
-
-    def forward(self, x, cond):
-        x1 = self.conv(x)
-        x2 = self.gcn(x)
-        x = torch.cat([x1, x2], dim=1)
-        x = self.fusion(x)
-        x = self.inject(x, cond)
-        return x
-
-class GinkaEncoder(nn.Module):
-    """Encoder (downsampling) part"""
-    def __init__(self, in_ch, out_ch):
-        super().__init__()
-        self.conv = ConvBlock(in_ch, out_ch)
-        self.pool = nn.MaxPool2d(2)
+        self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
         self.inject = ConditionInjector(256, out_ch)
 
     def forward(self, x, cond):
         x = self.conv(x)
-        x = self.pool(x)
         x = self.inject(x, cond)
         return x
 
-class GinkaGCNFusedEncoder(nn.Module):
+class GinkaEncoder(nn.Module):
     def __init__(self, in_ch, out_ch, w, h):
         super().__init__()
-        self.conv = ConvBlock(in_ch, out_ch)
-        self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
         self.pool = nn.MaxPool2d(2)
-        self.fusion = FusionModule(out_ch*2, out_ch)
+        self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
         self.inject = ConditionInjector(256, out_ch)
 
     def forward(self, x, cond):
-        x = self.conv(x)
         x = self.pool(x)
-        x2 = self.gcn(x)
-        x = self.fusion(x, x2)
+        x = self.conv(x)
         x = self.inject(x, cond)
         return x
@@ -114,42 +92,29 @@ class GinkaUpSample(nn.Module):
         self.conv = nn.Sequential(
             nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
             nn.InstanceNorm2d(out_ch),
-            nn.ELU(),
+            nn.GELU(),
+
+            nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
+            nn.InstanceNorm2d(out_ch),
+            nn.GELU()
         )
 
     def forward(self, x):
         return self.conv(x)
 
-class GinkaDecoder(nn.Module):
-    """Decoder (upsampling) part"""
-    def __init__(self, in_ch, out_ch):
-        super().__init__()
-        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
-        self.conv = ConvBlock(in_ch, out_ch)
-        self.inject = ConditionInjector(256, out_ch)
-
-    def forward(self, x, feat, cond):
-        x = self.upsample(x)
-        x = torch.cat([x, feat], dim=1)
-        x = self.conv(x)
-        x = self.inject(x, cond)
-        return x
-
-class GinkaGCNFusedDecoder(nn.Module):
+class GinkaDecoder(nn.Module):
     def __init__(self, in_ch, out_ch, w, h):
         super().__init__()
         self.upsample = GinkaUpSample(in_ch, in_ch // 2)
-        self.conv = ConvBlock(in_ch, out_ch)
-        self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
-        self.fusion = FusionModule(out_ch*2, out_ch)
+        self.fusion = nn.Conv2d(in_ch, in_ch, 1)
+        self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
         self.inject = ConditionInjector(256, out_ch)
 
     def forward(self, x, feat, cond):
         x = self.upsample(x)
         x = torch.cat([x, feat], dim=1)
+        x = self.fusion(x)
         x = self.conv(x)
-        x2 = self.gcn(x)
-        x = self.fusion(x, x2)
         x = self.inject(x, cond)
         return x
@@ -162,58 +127,62 @@ class GinkaBottleneck(nn.Module):
         # )
         # self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
         # self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
-        self.conv = ConvBlock(module_ch, module_ch)
-        self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, w, h)
-        self.fusion = nn.Conv2d(module_ch*2, module_ch, 1)
+        self.conv = ConvFusionModule(module_ch, module_ch, module_ch, w, h)
         self.inject = ConditionInjector(256, module_ch)
 
     def forward(self, x, cond):
         B = x.size(0)
 
         # x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
         # x1 = self.transformer(x1)
         # x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
-        x1 = self.conv(x)
-        x2 = self.gcn(x)
-
-        x = torch.cat([x1, x2], dim=1)
-        x = self.fusion(x)
+        x = self.conv(x)
         x = self.inject(x, cond)
 
         return x
 
-class GinkaUNet(nn.Module):
-    def __init__(self, in_ch=32, base_ch=64, out_ch=32):
-        """The UNet part of the Ginka model
-        """
-
+class GinkaEncoderPath(nn.Module):
+    def __init__(self, in_ch, base_ch):
         super().__init__()
         self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32)
-        self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
-        self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
-        self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4)
-        self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
-
-        self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8)
-        self.up2 = GinkaGCNFusedDecoder(base_ch*4, base_ch*2, 16, 16)
-        self.up3 = GinkaGCNFusedDecoder(base_ch*2, base_ch, 32, 32)
-
-        self.final = nn.Sequential(
-            nn.Conv2d(base_ch, out_ch, 1),
-            nn.InstanceNorm2d(out_ch),
-            nn.ELU(),
-        )
+        self.down2 = GinkaEncoder(base_ch, base_ch*2, 16, 16)
+        self.down3 = GinkaEncoder(base_ch*2, base_ch*4, 8, 8)
+        self.down4 = GinkaEncoder(base_ch*4, base_ch*8, 4, 4)
 
     def forward(self, x, cond):
         x1 = self.down1(x, cond) # [B, 64, 32, 32]
         x2 = self.down2(x1, cond) # [B, 128, 16, 16]
         x3 = self.down3(x2, cond) # [B, 256, 8, 8]
         x4 = self.down4(x3, cond) # [B, 512, 4, 4]
-        x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4]
-
-        # Upsample
+        return x1, x2, x3, x4
+
+class GinkaDecoderPath(nn.Module):
+    def __init__(self, base_ch):
+        super().__init__()
+        self.up1 = GinkaDecoder(base_ch*8, base_ch*4, 8, 8)
+        self.up2 = GinkaDecoder(base_ch*4, base_ch*2, 16, 16)
+        self.up3 = GinkaDecoder(base_ch*2, base_ch, 32, 32)
+
+    def forward(self, x1, x2, x3, x4, cond):
         x = self.up1(x4, x3, cond) # [B, 256, 8, 8]
         x = self.up2(x, x2, cond) # [B, 128, 16, 16]
         x = self.up3(x, x1, cond) # [B, 64, 32, 32]
         return x
+
+class GinkaUNet(nn.Module):
+    def __init__(self, in_ch=32, base_ch=32, out_ch=32):
+        """The UNet part of the Ginka model
+        """
+        super().__init__()
+        self.enc = GinkaEncoderPath(in_ch, base_ch)
+        self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
+        self.dec = GinkaDecoderPath(base_ch)
+
+        self.final = ConvFusionModule(base_ch, base_ch, out_ch, 32, 32)
+
+    def forward(self, x, cond):
+        x1, x2, x3, x4 = self.enc(x, cond)
+        x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4]
+        x = self.dec(x1, x2, x3, x4, cond)
+
+        x = self.final(x) # [B, 32, 32, 32]
+
+        return x
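Note (reviewer sketch, not part of the diff): splitting GinkaUNet into GinkaEncoderPath and GinkaDecoderPath lets RandomInputHead reuse the same trunk with skip connections. One decoder step, with channels following the diff's comments (base_ch=32):

    import torch
    import torch.nn as nn

    up = nn.ConvTranspose2d(256, 128, 2, stride=2)  # halve channels, double resolution
    fuse = nn.Conv2d(256, 256, 1)                   # 1x1 fusion after concatenating the skip

    x4 = torch.randn(1, 256, 4, 4)  # bottleneck output
    x3 = torch.randn(1, 128, 8, 8)  # encoder skip
    x = up(x4)                      # 4x4 -> 8x8
    x = fuse(torch.cat([x, x3], dim=1))
    print(x.shape)                  # torch.Size([1, 256, 8, 8]); a ConvFusionModule then maps to out_ch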
@@ -6,12 +6,13 @@ import torch
 import torch.optim as optim
 import torch.nn.functional as F
 import cv2
+import numpy as np
 from torch_geometric.loader import DataLoader
 from tqdm import tqdm
 from .generator.model import GinkaModel
 from .dataset import GinkaWGANDataset
 from .generator.loss import WGANGinkaLoss
-from .critic.model import MinamoModel
+from .critic.model import MinamoModel2
 from shared.image import matrix_to_image_cv
 
 # Tag definitions:
@@ -105,7 +106,7 @@ def train():
     stage_epoch = 0 # epoch count within the current stage, used to steer training
 
     ginka = GinkaModel().to(device)
-    minamo = MinamoModel().to(device)
+    minamo = MinamoModel2().to(device)
 
     dataset = GinkaWGANDataset(args.train, device)
     dataset_val = GinkaWGANDataset(args.validate, device)
@@ -113,7 +114,7 @@ def train():
     dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
 
     optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
-    optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
+    optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9))
 
     # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
    # scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
@@ -201,14 +202,24 @@ def train():
                 fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
 
             elif train_stage == 3 or train_stage == 4:
-                fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
+                fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
+
+            if train_stage == 4:
+                loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, masked2, x_in, tag_cond, val_cond)
 
             loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond)
             loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond)
             loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond)
 
-            dis_avg = (dis1 + dis2 + dis3) / 3.0
-            loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
+            dis = [dis1, dis2, dis3]
+            loss_d = [loss_d1, loss_d2, loss_d3]
+
+            if train_stage == 4:
+                dis.append(dis0)
+                loss_d.append(loss_d0)
+
+            dis_avg = sum(dis) / len(dis)
+            loss_d_avg = sum(loss_d) / len(loss_d)
 
             # Backward pass
             loss_d_avg.backward()
@@ -230,7 +241,7 @@ def train():
                 loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
                 loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)
 
-                loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
+                loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0
                 loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
 
                 loss_g.backward()
@@ -240,19 +251,16 @@ def train():
 
             elif train_stage == 3 or train_stage == 4:
                 fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
-
-                if train_stage == 3:
-                    loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1, tag_cond, val_cond)
-                else:
-                    loss_g1 = criterion.generator_loss_total(minamo, 1, fake1, tag_cond, val_cond)
-
+                loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, x_in, tag_cond, val_cond)
                 loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond)
                 loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)
 
                 if train_stage == 4:
-                    loss_head = criterion.generator_input_head_loss(x_in)
+                    loss_head = criterion.generator_input_head_loss(minamo, x_in, tag_cond, val_cond)
                     loss_head.backward(retain_graph=True)
 
-                loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
+                loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0
                 loss_g.backward()
                 optimizer_ginka.step()
                 loss_total_ginka += loss_g.detach()
@@ -286,6 +294,8 @@ def train():
             }, f"result/wgan/minamo-{epoch + 1}.pth")
 
         idx = 0
+        gap = 5
+        color = (255, 255, 255) # white
         with torch.no_grad():
            for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                real1 = batch["real1"].to(device)
@@ -301,17 +311,42 @@ def train():
                     fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
 
                 elif train_stage == 3 or train_stage == 4:
-                    fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
+                    fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
+                    x_in = torch.argmax(x_in, dim=1).cpu().numpy()
 
                 fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
                 fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
                 fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
 
+                masked1 = torch.argmax(masked1, dim=1).cpu().numpy()
+                masked2 = torch.argmax(masked2, dim=1).cpu().numpy()
+                masked3 = torch.argmax(masked3, dim=1).cpu().numpy()
+
                 for i in range(fake1.shape[0]):
-                    for key, one in enumerate([fake1, fake2, fake3]):
-                        map_matrix = one[i]
-                        image = matrix_to_image_cv(map_matrix, tile_dict)
-                        cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
+                    fake1_img = matrix_to_image_cv(fake1[i], tile_dict)
+                    fake2_img = matrix_to_image_cv(fake2[i], tile_dict)
+                    fake3_img = matrix_to_image_cv(fake3[i], tile_dict)
+                    if train_stage == 1 or train_stage == 2:
+                        vline = np.full((416, gap, 3), color, dtype=np.uint8) # vertical divider
+                        hline = np.full((gap, 3 * 416 + gap * 2, 3), color, dtype=np.uint8) # horizontal divider
+                        in1_img = matrix_to_image_cv(masked1[i], tile_dict)
+                        in2_img = matrix_to_image_cv(masked2[i], tile_dict)
+                        in3_img = matrix_to_image_cv(masked3[i], tile_dict)
+                        img = np.block([
+                            [[in1_img], [vline], [in2_img], [vline], [in3_img]],
+                            [[hline]],
+                            [[fake1_img], [vline], [fake2_img], [vline], [fake3_img]]
+                        ])
+                    elif train_stage == 3 or train_stage == 4:
+                        vline = np.full((416, gap, 3), color, dtype=np.uint8) # vertical divider
+                        hline = np.full((gap, 2 * 416 + gap, 3), color, dtype=np.uint8) # horizontal divider
+                        in_img = matrix_to_image_cv(x_in[i], tile_dict)
+                        img = np.block([
+                            [[in_img], [vline], [fake1_img]],
+                            [[hline]],
+                            [[fake2_img], [vline], [fake3_img]]
+                        ])
+
+                    cv2.imwrite(f"result/ginka_img/{idx}.png", img)
 
                     idx += 1
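Note (reviewer sketch, not part of the diff): the validation loop now tiles inputs and outputs into one comparison image with np.block. With H x W x 3 images, the extra level of brackets makes np.block concatenate along width inside a row and along height across rows (innermost lists bind to the last axis). A toy version of the same layout:

    import numpy as np

    tile = np.zeros((4, 4, 3), dtype=np.uint8)               # stand-in for a rendered map
    vline = np.full((4, 1, 3), 255, dtype=np.uint8)          # vertical divider
    hline = np.full((1, 4 * 2 + 1, 3), 255, dtype=np.uint8)  # horizontal divider
    img = np.block([
        [[tile], [vline], [tile]],
        [[hline]],
        [[tile], [vline], [tile]],
    ])
    print(img.shape)  # (9, 9, 3)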