From fa48863946eebf9161ab7e533f062542815628ce Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sun, 11 May 2025 23:50:08 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E6=A8=A1=E5=9E=8B=E5=BE=AE=E8=B0=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/common/common.py | 22 ++--- ginka/common/cond.py | 28 ++---- ginka/critic/model.py | 193 ++++++++++++++++++++++++++++++++++++-- ginka/critic/topo.py | 10 +- ginka/critic/vision.py | 7 +- ginka/dataset.py | 3 +- ginka/generator/input.py | 34 +++---- ginka/generator/loss.py | 91 +++++++++--------- ginka/generator/output.py | 20 +--- ginka/generator/unet.py | 151 ++++++++++++----------------- ginka/train_wgan.py | 75 +++++++++++---- 11 files changed, 393 insertions(+), 241 deletions(-) diff --git a/ginka/common/common.py b/ginka/common/common.py index 59e70b0..2ac3c7c 100644 --- a/ginka/common/common.py +++ b/ginka/common/common.py @@ -19,11 +19,11 @@ class DoubleConvBlock(nn.Module): self.cnn = nn.Sequential( nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate'), nn.InstanceNorm2d(feats[1]), - nn.ELU(), + nn.GELU(), nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate'), nn.InstanceNorm2d(feats[2]), - nn.ELU(), + nn.GELU(), ) def forward(self, x): @@ -57,11 +57,11 @@ class GCNBlock(nn.Module): # GCN forward x = self.conv1(x, edge_index) - x = F.elu(self.norm1(x)) + x = F.gelu(self.norm1(x)) x = self.conv2(x, edge_index) - x = F.elu(self.norm2(x)) + x = F.gelu(self.norm2(x)) x = self.conv3(x, edge_index) - x = F.elu(self.norm3(x)) + x = F.gelu(self.norm3(x)) # Reshape back to [B, C, H, W] x = x.view(B, H, W, -1).permute(0, 3, 1, 2) @@ -92,9 +92,9 @@ class TransformerGCNBlock(nn.Module): # GCN forward x = self.conv1(x, edge_index) - x = F.elu(self.norm1(x)) + x = F.gelu(self.norm1(x)) x = self.conv2(x, edge_index) - x = F.elu(self.norm2(x)) + x = F.gelu(self.norm2(x)) # Reshape back to [B, C, H, W] x = x.view(B, H, W, -1).permute(0, 3, 1, 2) @@ -104,8 +104,8 @@ class ConvFusionModule(nn.Module): def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int): super().__init__() self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch]) - self.gcn = GCNBlock(in_ch, hidden_ch, in_ch, w, h) - self.fusion = DoubleConvBlock([in_ch*2, hidden_ch*2, out_ch]) + self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h) + self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch]) def forward(self, x): x1 = self.cnn(x) @@ -120,11 +120,11 @@ class DoubleFCModule(nn.Module): self.fc = nn.Sequential( nn.Linear(in_dim, hidden_dim), nn.LayerNorm(hidden_dim), - nn.ELU(), + nn.GELU(), nn.Linear(hidden_dim, out_dim), nn.LayerNorm(out_dim), - nn.ELU() + nn.GELU() ) def forward(self, x): diff --git a/ginka/common/cond.py b/ginka/common/cond.py index bc972b6..485bad5 100644 --- a/ginka/common/cond.py +++ b/ginka/common/cond.py @@ -6,22 +6,22 @@ from .common import DoubleFCModule class ConditionEncoder(nn.Module): def __init__(self, tag_dim, val_dim, hidden_dim, out_dim): super().__init__() - self.tag_embed = DoubleFCModule(tag_dim, hidden_dim*2, hidden_dim) - self.val_embed = DoubleFCModule(val_dim, hidden_dim*2, hidden_dim) - self.stage_embed = DoubleFCModule(1, hidden_dim*2, hidden_dim) + self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim) + self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim) + self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim) self.encoder = nn.TransformerEncoder( nn.TransformerEncoderLayer( d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4, batch_first=True ), - num_layers=6 + num_layers=4 ) self.fusion = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim*2), - nn.LayerNorm(hidden_dim*2), - nn.ELU(), + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.GELU(), - nn.Linear(hidden_dim*2, out_dim) + nn.Linear(hidden_dim, out_dim) ) def forward(self, tag, val, stage): @@ -38,18 +38,10 @@ class ConditionInjector(nn.Module): def __init__(self, cond_dim, out_dim): super().__init__() self.gamma_layer = nn.Sequential( - nn.Linear(cond_dim, cond_dim*2), - nn.LayerNorm(cond_dim*2), - nn.ELU(), - - nn.Linear(cond_dim*2, out_dim) + nn.Linear(cond_dim, out_dim) ) self.beta_layer = nn.Sequential( - nn.Linear(cond_dim, cond_dim*2), - nn.LayerNorm(cond_dim*2), - nn.ELU(), - - nn.Linear(cond_dim*2, out_dim) + nn.Linear(cond_dim, out_dim) ) def forward(self, x, cond): diff --git a/ginka/critic/model.py b/ginka/critic/model.py index 72415fd..548d451 100644 --- a/ginka/critic/model.py +++ b/ginka/critic/model.py @@ -2,22 +2,138 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils import spectral_norm -from torch_geometric.nn import global_max_pool, GCNConv +from torch_geometric.nn import global_max_pool, GCNConv, TransformerConv +from torch_geometric.utils import grid from shared.constant import VISION_WEIGHT, TOPO_WEIGHT -from shared.graph import batch_convert_soft_map_to_graph from .vision import MinamoVisionModel from .topo import MinamoTopoModel -from ..common.cond import ConditionEncoder def print_memory(tag=""): print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") +def batch_edge_index(B, edge_index, num_nodes_per_batch): + # 批次偏移 edge_index + edge_index = edge_index.clone() # [2, E] + batch_edge_index = [] + for i in range(B): + offset = i * num_nodes_per_batch + batch_edge_index.append(edge_index + offset) + return torch.cat(batch_edge_index, dim=1) + +class DoubleConvBlock(nn.Module): + def __init__(self, feats: tuple[int, int, int]): + super().__init__() + self.cnn = nn.Sequential( + spectral_norm(nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate')), + nn.GELU(), + + spectral_norm(nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate')), + nn.GELU(), + ) + + def forward(self, x): + x = self.cnn(x) + return x + +class TransformerGCNBlock(nn.Module): + def __init__(self, in_ch, hidden_ch, out_ch, w, h): + super().__init__() + self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True) + self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1) + self.single_edge_index, _ = grid(h, w) # [2, E] for a single map + + def forward(self, x): + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1).reshape(B * H * W, C) + device = x.device + edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W) + x = self.conv1(x, edge_index) + x = F.gelu(x) + x = self.conv2(x, edge_index) + x = F.gelu(x) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + return x + +class ConvFusionModule(nn.Module): + def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int): + super().__init__() + self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch]) + self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h) + self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch]) + + def forward(self, x): + x1 = self.cnn(x) + x2 = self.gcn(x) + x = torch.cat([x1, x2], dim=1) + x = self.fusion(x) + return x + +class DoubleFCModule(nn.Module): + def __init__(self, in_dim, hidden_dim, out_dim): + super().__init__() + self.fc = nn.Sequential( + spectral_norm(nn.Linear(in_dim, hidden_dim)), + nn.GELU(), + + spectral_norm(nn.Linear(hidden_dim, out_dim)), + nn.GELU() + ) + + def forward(self, x): + x = self.fc(x) + return x + +class ConditionEncoder(nn.Module): + def __init__(self, tag_dim, val_dim, hidden_dim, out_dim): + super().__init__() + self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim) + self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim) + self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim) + self.encoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer( + d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4, + batch_first=True + ), + num_layers=4 + ) + self.fusion = nn.Sequential( + spectral_norm(nn.Linear(hidden_dim, hidden_dim)), + nn.GELU(), + + spectral_norm(nn.Linear(hidden_dim, out_dim)) + ) + + def forward(self, tag, val, stage): + tag = self.tag_embed(tag) + val = self.val_embed(val) + stage = self.stage_embed(stage) + feat = torch.stack([tag, val, stage], dim=1) + feat = self.encoder(feat) + feat = torch.mean(feat, dim=1) + feat = self.fusion(feat) + return feat + +class ConditionInjector(nn.Module): + def __init__(self, cond_dim, out_dim): + super().__init__() + self.gamma_layer = nn.Sequential( + spectral_norm(nn.Linear(cond_dim, out_dim)) + ) + self.beta_layer = nn.Sequential( + spectral_norm(nn.Linear(cond_dim, out_dim)) + ) + + def forward(self, x, cond): + gamma = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3) + beta = self.beta_layer(cond).unsqueeze(2).unsqueeze(3) + return x * gamma + beta + class CNNHead(nn.Module): def __init__(self, in_ch): super().__init__() self.cnn = nn.Sequential( spectral_norm(nn.Conv2d(in_ch, in_ch, 3)), - nn.LeakyReLU(0.2), + nn.GELU(), nn.AdaptiveMaxPool2d((2, 2)) ) @@ -46,7 +162,7 @@ class GCNHead(nn.Module): def forward(self, x, graph, cond): x = self.gcn(x, graph.edge_index) - x = F.leaky_relu(x, 0.2) + x = F.gelu(x) x = global_max_pool(x, graph.batch) cond = self.proj(cond) proj = torch.sum(x * cond, dim=1, keepdim=True) @@ -91,6 +207,65 @@ class MinamoModel(nn.Module): raise RuntimeError("Unknown critic stage.") score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score return score, vision_score, topo_score + +class MinamoHead2(nn.Module): + def __init__(self, in_ch, hidden_ch): + super().__init__() + self.conv = ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13) + self.pool = nn.AdaptiveMaxPool2d(1) + self.proj = spectral_norm(nn.Linear(256, hidden_ch)) + self.fc = spectral_norm(nn.Linear(hidden_ch, 1)) + + def forward(self, x, cond): + x = self.conv(x) + x = self.pool(x) + x = x.squeeze(3).squeeze(2) + cond = self.proj(cond) + proj = torch.sum(x * cond, dim=1, keepdim=True) + x = self.fc(x) + proj + return x + +class MinamoModel2(nn.Module): + def __init__(self, tile_types=32): + super().__init__() + self.cond = ConditionEncoder(64, 16, 256, 256) + + self.conv1 = ConvFusionModule(tile_types, 256, 128, 13, 13) + self.conv2 = ConvFusionModule(128, 256, 256, 13, 13) + self.conv3 = ConvFusionModule(256, 512, 256, 13, 13) + + self.head0 = MinamoHead2(256, 256) # 随机头的判别头 + self.head1 = MinamoHead2(256, 256) + self.head2 = MinamoHead2(256, 256) + self.head3 = MinamoHead2(256, 256) + + self.inject1 = ConditionInjector(256, 128) + self.inject2 = ConditionInjector(256, 256) + self.inject3 = ConditionInjector(256, 256) + + def forward(self, x, stage, tag_cond, val_cond): + B, D = tag_cond.shape + stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device) + cond = self.cond(tag_cond, val_cond, stage_tensor) + x = self.conv1(x) + x = self.inject1(x, cond) + x = self.conv2(x) + x = self.inject2(x, cond) + x = self.conv3(x) + x = self.inject3(x, cond) + + if stage == 0: + score = self.head0(x, cond) + elif stage == 1: + score = self.head1(x, cond) + elif stage == 2: + score = self.head2(x, cond) + elif stage == 3: + score = self.head3(x, cond) + else: + raise RuntimeError("Unknown critic stage.") + + return score # 检查显存占用 if __name__ == "__main__": @@ -99,19 +274,19 @@ if __name__ == "__main__": val = torch.rand(1, 16).cuda() # 初始化模型 - model = MinamoModel().cuda() + model = MinamoModel2().cuda() print_memory("初始化后") # 前向传播 - output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1, tag, val) + output = model(input, 1, tag, val) print_memory("前向传播后") print(f"输入形状: feat={input.shape}") print(f"输出形状: output={output.shape}") + # print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}") + # print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}") print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}") - print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}") - print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}") print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}") print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/critic/topo.py b/ginka/critic/topo.py index 967f08f..43b7eaf 100644 --- a/ginka/critic/topo.py +++ b/ginka/critic/topo.py @@ -2,12 +2,12 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils import spectral_norm -from torch_geometric.nn import GATConv +from torch_geometric.nn import GATConv, TransformerConv from torch_geometric.data import Data class MinamoTopoModel(nn.Module): def __init__( - self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512 + self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512 ): super().__init__() # 传入 softmax 概率值,直接映射 @@ -16,9 +16,9 @@ class MinamoTopoModel(nn.Module): nn.LeakyReLU(0.2) ) # 图卷积层 - self.conv1 = GATConv(emb_dim, hidden_dim, heads=8) - self.conv2 = GATConv(hidden_dim*8, hidden_dim, heads=8) - self.conv3 = GATConv(hidden_dim*8, out_dim, heads=1) + self.conv1 = TransformerConv(emb_dim, hidden_dim, heads=8) + self.conv2 = TransformerConv(hidden_dim*8, hidden_dim, heads=8) + self.conv3 = TransformerConv(hidden_dim*8, out_dim, heads=1) def forward(self, graph: Data): x = self.input_proj(graph.x) diff --git a/ginka/critic/vision.py b/ginka/critic/vision.py index de317b3..6e7b847 100644 --- a/ginka/critic/vision.py +++ b/ginka/critic/vision.py @@ -10,13 +10,10 @@ class MinamoVisionModel(nn.Module): spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11 nn.LeakyReLU(0.2), - spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #9*9 + spectral_norm(nn.Conv2d(in_ch*2, in_ch*8, 3)), #9*9 nn.LeakyReLU(0.2), - spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7 - nn.LeakyReLU(0.2), - - spectral_norm(nn.Conv2d(in_ch*8, out_ch, 3)), # 5*5 + spectral_norm(nn.Conv2d(in_ch*8, out_ch, 3)), # 7*7 nn.LeakyReLU(0.2), ) diff --git a/ginka/dataset.py b/ginka/dataset.py index 417a0e9..886afda 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -142,13 +142,14 @@ class GinkaWGANDataset(Dataset): removed1 = apply_curriculum_remove(target, STAGE1_REMOVE) removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) + _, masked = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, 0.5) rand = torch.rand(32, 32, 32, device=target.device) return { "real1": removed1, "masked1": rand, "real2": removed2, - "masked2": torch.zeros_like(target), + "masked2": masked, "real3": removed3, "masked3": torch.zeros_like(target), "tag_cond": tag_cond, diff --git a/ginka/generator/input.py b/ginka/generator/input.py index e8339ac..c6fcef6 100644 --- a/ginka/generator/input.py +++ b/ginka/generator/input.py @@ -2,24 +2,25 @@ import torch import torch.nn as nn from ..common.common import ConvFusionModule from ..common.cond import ConditionInjector +from .unet import GinkaEncoderPath, GinkaDecoderPath class RandomInputHead(nn.Module): def __init__(self): super().__init__() - self.enc = ConvFusionModule(32, 256, 256, 32, 32) + self.enc = GinkaEncoderPath(32, 32) + self.dec = GinkaDecoderPath(32) self.out_conv = nn.Sequential( - nn.Conv2d(256, 128, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(128), - nn.ELU(), + nn.AdaptiveMaxPool2d((15, 15)), + nn.Conv2d(32, 64, 3, padding=0), + nn.InstanceNorm2d(64), + nn.GELU(), - nn.AdaptiveMaxPool2d((13, 13)), - nn.Conv2d(128, 32, 1), + nn.Conv2d(64, 32, 1), ) - self.inject = ConditionInjector(256, 256) def forward(self, x, cond): - x = self.enc(x) - x = self.inject(x, cond) + x1, x2, x3, x4 = self.enc(x, cond) + x = self.dec(x1, x2, x3, x4, cond) x = self.out_conv(x) return x @@ -28,15 +29,12 @@ class InputUpsample(nn.Module): super().__init__() self.net = nn.Sequential( ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13), - nn.ELU(), nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26 ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, 26, 26), - nn.ELU(), nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32 ConvFusionModule(hidden_ch, hidden_ch, out_ch, 32, 32), - nn.ELU(), ) def forward(self, x): # [B, C, 13, 13] @@ -47,18 +45,14 @@ class GinkaInput(nn.Module): def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)): super().__init__() self.out_size = out_size - self.enc1 = ConvFusionModule(in_ch, in_ch*4, in_ch, in_size[0], in_size[1]) self.upsample = InputUpsample(in_ch, in_ch*2, out_ch) - self.enc2 = ConvFusionModule(out_ch, out_ch*4, out_ch, out_size[0], out_size[1]) - self.inject1 = ConditionInjector(256, in_ch) + self.enc = ConvFusionModule(out_ch, out_ch*2, out_ch, out_size[0], out_size[1]) + self.inject1 = ConditionInjector(256, out_ch) self.inject2 = ConditionInjector(256, out_ch) - self.inject3 = ConditionInjector(256, out_ch) def forward(self, x, cond): - x = self.enc1(x) - x = self.inject1(x, cond) x = self.upsample(x) + x = self.inject1(x, cond) + x = self.enc(x) x = self.inject2(x, cond) - x = self.enc2(x) - x = self.inject3(x, cond) return x diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index b05d917..44b59b5 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -1,12 +1,7 @@ -import math -from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.data import Data -from shared.graph import batch_convert_soft_map_to_graph -from shared.constant import VISION_WEIGHT, TOPO_WEIGHT -from ..critic.model import MinamoModel CLASS_NUM = 32 ILLEGAL_MAX_NUM = 30 @@ -156,15 +151,15 @@ def entrance_constraint_loss( ) return total_loss -def input_head_illegal_loss(input_map, allowed_classes=(0, 1)): +def input_head_illegal_loss(input_map, allowed_classes=[0, 1, 2]): C = input_map.shape[1] - mask = torch.ones(C, device=input_map.device) - mask[list(allowed_classes)] = 0 # 屏蔽允许的类别,其余为 1 - illegal_class_penalty = (input_map * mask.view(1, -1, 1, 1)).sum() / input_map.numel() - - return illegal_class_penalty + unallowed = get_not_allowed(allowed_classes, include_illegal=True) + illegal = input_map[:, unallowed, :, :] + penalty = torch.sum(illegal) -def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1): + return penalty + +def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=[1, 2]): wall_prob = input_map[:, wall_class] # [B, H, W] wall_ratio = wall_prob.mean() # 计算平均墙体占比 wall_penalty = torch.clamp(wall_ratio - max_wall_ratio, min=0.0) # 超过则惩罚 @@ -241,6 +236,16 @@ def immutable_penalty_loss( return penalty +def modifiable_penalty_loss( + probs: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int] +) -> torch.Tensor: + target_modifiable = input[:, modifiable_classes, :, :] + pred_modifiable = probs[:, modifiable_classes, :, :] + existed = torch.clamp(target_modifiable - pred_modifiable, min=0.0, max=1.0) + penalty = F.mse_loss(existed, torch.zeros_like(existed, device=existed.device)) + + return penalty + def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]): not_allowed = get_not_allowed(legal_classes, include_illegal=True) input_mask = pred[:, not_allowed, :, :] @@ -249,43 +254,40 @@ def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]): return penalty class WGANGinkaLoss: - def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.05, 0.5]): - # weight: 判别器损失,CE 损失,不可修改类型损失和非法图块损失,图块类型损失,入口存在性损失,多样性损失,密度损失 + def __init__(self, lambda_gp=100, weight=[1, 0.4, 50, 0.2, 0.2, 0.05, 0.4]): + # weight: + # 1. 判别器损失及图块维持损失(可修改部分的已有内容不可修改) + # 2. CE 损失 + # 3. 不可修改类型损失和非法图块损失 + # 4. 图块类型损失 + # 5. 入口存在性损失 + # 6. 多样性损失 + # 7. 密度损失 self.lambda_gp = lambda_gp # 梯度惩罚系数 self.weight = weight def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond): # 进行插值 batch_size = real_data.size(0) - epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device) + epsilon_data = torch.rand(batch_size, 1, 1, 1, device=real_data.device) interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device) - interp_graph = batch_convert_soft_map_to_graph(interp_data).to(real_data.device) # 对图像进行反向传播并计算梯度 interp_data.requires_grad_() - interp_graph.x.requires_grad_() - _, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage, tag_cond, val_cond) + d_score = critic(interp_data, stage, tag_cond, val_cond) # 计算梯度 - grad_vis = torch.autograd.grad( - outputs=d_vis_score, inputs=interp_data, - grad_outputs=torch.ones_like(d_vis_score), - create_graph=True, retain_graph=True, only_inputs=True - )[0] - grad_topo = torch.autograd.grad( - outputs=d_topo_score, inputs=interp_graph.x, - grad_outputs=torch.ones_like(d_topo_score), + grad = torch.autograd.grad( + outputs=d_score, inputs=interp_data, + grad_outputs=torch.ones_like(d_score), create_graph=True, retain_graph=True, only_inputs=True )[0] # 计算梯度的 L2 范数 - grad_norm_vis = grad_vis.view(batch_size, -1).norm(2, dim=1) - grad_norm_topo = grad_topo.view(batch_size, -1).norm(2, dim=1) + grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1) # 计算梯度惩罚项 - gp_loss_vis = ((grad_norm_vis - 1.0) ** 2).mean() - gp_loss_topo = ((grad_norm_topo - 1.0) ** 2).mean() - gp_loss = gp_loss_vis * VISION_WEIGHT + gp_loss_topo * TOPO_WEIGHT + gp_loss = ((grad_norm - 1.0) ** 2).mean() # print(grad_norm_topo.mean().item(), grad_norm_vis.mean().item()) return gp_loss @@ -296,10 +298,8 @@ class WGANGinkaLoss: ) -> tuple[torch.Tensor, torch.Tensor]: """ 判别器损失函数 """ fake_data = F.softmax(fake_data, dim=1) - real_graph = batch_convert_soft_map_to_graph(real_data) - fake_graph = batch_convert_soft_map_to_graph(fake_data) - real_scores, _, _ = critic(real_data, real_graph, stage, tag_cond, val_cond) - fake_scores, _, _ = critic(fake_data, fake_graph, stage, tag_cond, val_cond) + real_scores = critic(real_data, stage, tag_cond, val_cond) + fake_scores = critic(fake_data, stage, tag_cond, val_cond) # Wasserstein 距离 d_loss = fake_scores.mean() - real_scores.mean() @@ -312,10 +312,9 @@ class WGANGinkaLoss: def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ 生成器损失函数 """ probs_fake = F.softmax(fake, dim=1) - fake_graph = batch_convert_soft_map_to_graph(probs_fake) - fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) - minamo_loss = -torch.mean(fake_scores) + fake_scores = critic(probs_fake, stage, tag_cond, val_cond) + minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage]) ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小 immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage]) constraint_loss = inner_constraint_loss(probs_fake) @@ -343,9 +342,8 @@ class WGANGinkaLoss: def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor: probs_fake = F.softmax(fake, dim=1) - fake_graph = batch_convert_soft_map_to_graph(probs_fake) - fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) + fake_scores = critic(probs_fake, stage, tag_cond, val_cond) minamo_loss = -torch.mean(fake_scores) illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage]) constraint_loss = inner_constraint_loss(probs_fake) @@ -370,10 +368,9 @@ class WGANGinkaLoss: def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor: probs_fake = F.softmax(fake, dim=1) - fake_graph = batch_convert_soft_map_to_graph(probs_fake) - fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) - minamo_loss = -torch.mean(fake_scores) + fake_scores = critic(probs_fake, stage, tag_cond, val_cond) + minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage]) immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage]) constraint_loss = inner_constraint_loss(probs_fake) density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) @@ -395,13 +392,15 @@ class WGANGinkaLoss: return sum(losses) - def generator_input_head_loss(self, probs: torch.Tensor) -> torch.Tensor: + def generator_input_head_loss(self, critic, map: torch.Tensor, tag_cond, val_cond) -> torch.Tensor: + probs = F.softmax(map, dim=1) + head_scores = critic(probs, 0, tag_cond, val_cond) probs_a, probs_b = probs.chunk(2, dim=0) losses = [ + torch.mean(head_scores), input_head_illegal_loss(probs), - input_head_wall_loss(probs), - -js_divergence(probs_a, probs_b, softmax=False) * 0.3 + -js_divergence(probs_a, probs_b, softmax=False) * 0.1 ] return sum(losses) diff --git a/ginka/generator/output.py b/ginka/generator/output.py index f9fa6f8..b63cac4 100644 --- a/ginka/generator/output.py +++ b/ginka/generator/output.py @@ -1,22 +1,15 @@ import torch import torch.nn as nn -from ..common.common import GCNBlock, DoubleConvBlock +from ..common.common import ConvFusionModule from ..common.cond import ConditionInjector class StageHead(nn.Module): def __init__(self, in_ch, out_ch, out_size=(13, 13)): super().__init__() - self.cnn_head = DoubleConvBlock([in_ch, in_ch*2, in_ch]) - self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32) - self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch]) + self.dec = ConvFusionModule(in_ch, in_ch*2, in_ch, 32, 32) self.pool = nn.Sequential( - nn.Conv2d(in_ch, in_ch*2, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(in_ch*2), - nn.ELU(), - - nn.Conv2d(in_ch*2, in_ch, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(in_ch), - nn.ELU(), + ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32), + ConvFusionModule(in_ch*2, in_ch*2, in_ch, 32, 32), nn.AdaptiveMaxPool2d(out_size), nn.Conv2d(in_ch, out_ch, 1) @@ -24,10 +17,7 @@ class StageHead(nn.Module): self.inject = ConditionInjector(256, in_ch) def forward(self, x, cond): - x_cnn = self.cnn_head(x) - x_gcn = self.gcn_head(x) - x = torch.cat([x_cnn, x_gcn], dim=1) - x = self.fusion(x) + x = self.dec(x) x = self.inject(x, cond) x = self.pool(x) return x diff --git a/ginka/generator/unet.py b/ginka/generator/unet.py index 839b2e9..c14db9a 100644 --- a/ginka/generator/unet.py +++ b/ginka/generator/unet.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from shared.attention import ChannelAttention -from ..common.common import GCNBlock, TransformerGCNBlock +from ..common.common import GCNBlock, TransformerGCNBlock, DoubleConvBlock, ConvFusionModule from ..common.cond import ConditionInjector class GinkaTransformerEncoder(nn.Module): @@ -37,16 +37,17 @@ class GinkaTransformerEncoder(nn.Module): class ConvBlock(nn.Module): def __init__(self, in_ch, out_ch, attn=True): super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(out_ch), - nn.ELU(), - nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(out_ch), - ) - if attn: - self.conv.append(ChannelAttention(out_ch)) - self.conv.append(nn.ELU()) + self.conv = DoubleConvBlock([in_ch, out_ch, out_ch]) + # self.conv = nn.Sequential( + # nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'), + # nn.InstanceNorm2d(out_ch), + # nn.ELU(), + # nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'), + # nn.InstanceNorm2d(out_ch), + # ) + # if attn: + # self.conv.append(ChannelAttention(out_ch)) + # self.conv.append(nn.ELU()) def forward(self, x): return self.conv(x) @@ -64,47 +65,24 @@ class FusionModule(nn.Module): class GinkaUNetInput(nn.Module): def __init__(self, in_ch, out_ch, w, h): super().__init__() - self.conv = ConvBlock(in_ch, in_ch) - self.gcn = TransformerGCNBlock(in_ch, in_ch*2, in_ch, w, h) - self.fusion = ConvBlock(in_ch*2, out_ch) - self.inject = ConditionInjector(256, out_ch) - - def forward(self, x, cond): - x1 = self.conv(x) - x2 = self.gcn(x) - x = torch.cat([x1, x2], dim=1) - x = self.fusion(x) - x = self.inject(x, cond) - return x - -class GinkaEncoder(nn.Module): - """编码器(下采样)部分""" - def __init__(self, in_ch, out_ch): - super().__init__() - self.conv = ConvBlock(in_ch, out_ch) - self.pool = nn.MaxPool2d(2) + self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h) self.inject = ConditionInjector(256, out_ch) def forward(self, x, cond): x = self.conv(x) - x = self.pool(x) x = self.inject(x, cond) return x -class GinkaGCNFusedEncoder(nn.Module): +class GinkaEncoder(nn.Module): def __init__(self, in_ch, out_ch, w, h): super().__init__() - self.conv = ConvBlock(in_ch, out_ch) - self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h) self.pool = nn.MaxPool2d(2) - self.fusion = FusionModule(out_ch*2, out_ch) + self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h) self.inject = ConditionInjector(256, out_ch) def forward(self, x, cond): - x = self.conv(x) x = self.pool(x) - x2 = self.gcn(x) - x = self.fusion(x, x2) + x = self.conv(x) x = self.inject(x, cond) return x @@ -114,42 +92,29 @@ class GinkaUpSample(nn.Module): self.conv = nn.Sequential( nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2), nn.InstanceNorm2d(out_ch), - nn.ELU(), + nn.GELU(), + + nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(out_ch), + nn.GELU() ) def forward(self, x): return self.conv(x) class GinkaDecoder(nn.Module): - """解码器(上采样)部分""" - def __init__(self, in_ch, out_ch): - super().__init__() - self.upsample = GinkaUpSample(in_ch, in_ch // 2) - self.conv = ConvBlock(in_ch, out_ch) - self.inject = ConditionInjector(256, out_ch) - - def forward(self, x, feat, cond): - x = self.upsample(x) - x = torch.cat([x, feat], dim=1) - x = self.conv(x) - x = self.inject(x, cond) - return x - -class GinkaGCNFusedDecoder(nn.Module): def __init__(self, in_ch, out_ch, w, h): super().__init__() self.upsample = GinkaUpSample(in_ch, in_ch // 2) - self.conv = ConvBlock(in_ch, out_ch) - self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h) - self.fusion = FusionModule(out_ch*2, out_ch) + self.fusion = nn.Conv2d(in_ch, in_ch, 1) + self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h) self.inject = ConditionInjector(256, out_ch) def forward(self, x, feat, cond): x = self.upsample(x) x = torch.cat([x, feat], dim=1) + x = self.fusion(x) x = self.conv(x) - x2 = self.gcn(x) - x = self.fusion(x, x2) x = self.inject(x, cond) return x @@ -162,58 +127,62 @@ class GinkaBottleneck(nn.Module): # ) # self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4) # self.fusion = nn.Conv2d(module_ch*3, module_ch, 1) - self.conv = ConvBlock(module_ch, module_ch) - self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, w, h) - self.fusion = nn.Conv2d(module_ch*2, module_ch, 1) + self.conv = ConvFusionModule(module_ch, module_ch, module_ch, w, h) self.inject = ConditionInjector(256, module_ch) def forward(self, x, cond): - B = x.size(0) - # x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch] # x1 = self.transformer(x1) # x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4] - x1 = self.conv(x) - x2 = self.gcn(x) - - x = torch.cat([x1, x2], dim=1) - x = self.fusion(x) + x = self.conv(x) x = self.inject(x, cond) - return x - -class GinkaUNet(nn.Module): - def __init__(self, in_ch=32, base_ch=64, out_ch=32): - """Ginka Model UNet 部分 - """ + +class GinkaEncoderPath(nn.Module): + def __init__(self, in_ch, base_ch): super().__init__() self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32) - self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16) - self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8) - self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4) - self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4) - - self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8) - self.up2 = GinkaGCNFusedDecoder(base_ch*4, base_ch*2, 16, 16) - self.up3 = GinkaGCNFusedDecoder(base_ch*2, base_ch, 32, 32) - - self.final = nn.Sequential( - nn.Conv2d(base_ch, out_ch, 1), - nn.InstanceNorm2d(out_ch), - nn.ELU(), - ) + self.down2 = GinkaEncoder(base_ch, base_ch*2, 16, 16) + self.down3 = GinkaEncoder(base_ch*2, base_ch*4, 8, 8) + self.down4 = GinkaEncoder(base_ch*4, base_ch*8, 4, 4) def forward(self, x, cond): x1 = self.down1(x, cond) # [B, 64, 32, 32] x2 = self.down2(x1, cond) # [B, 128, 16, 16] x3 = self.down3(x2, cond) # [B, 256, 8, 8] x4 = self.down4(x3, cond) # [B, 512, 4, 4] - x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4] - # 上采样 + return x1, x2, x3, x4 + +class GinkaDecoderPath(nn.Module): + def __init__(self, base_ch): + super().__init__() + self.up1 = GinkaDecoder(base_ch*8, base_ch*4, 8, 8) + self.up2 = GinkaDecoder(base_ch*4, base_ch*2, 16, 16) + self.up3 = GinkaDecoder(base_ch*2, base_ch, 32, 32) + + def forward(self, x1, x2, x3, x4, cond): x = self.up1(x4, x3, cond) # [B, 256, 8, 8] x = self.up2(x, x2, cond) # [B, 128, 16, 16] x = self.up3(x, x1, cond) # [B, 64, 32, 32] + return x + +class GinkaUNet(nn.Module): + def __init__(self, in_ch=32, base_ch=32, out_ch=32): + """Ginka Model UNet 部分 + """ + super().__init__() + self.enc = GinkaEncoderPath(in_ch, base_ch) + self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4) + self.dec = GinkaDecoderPath(base_ch) + + self.final = ConvFusionModule(base_ch, base_ch, out_ch, 32, 32) + + def forward(self, x, cond): + x1, x2, x3, x4 = self.enc(x, cond) + x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4] + x = self.dec(x1, x2, x3, x4, cond) + x = self.final(x) # [B, 32, 32, 32] return x diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 9f63e6f..5c700c5 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -6,12 +6,13 @@ import torch import torch.optim as optim import torch.nn.functional as F import cv2 +import numpy as np from torch_geometric.loader import DataLoader from tqdm import tqdm from .generator.model import GinkaModel from .dataset import GinkaWGANDataset from .generator.loss import WGANGinkaLoss -from .critic.model import MinamoModel +from .critic.model import MinamoModel2 from shared.image import matrix_to_image_cv # 标签定义: @@ -105,7 +106,7 @@ def train(): stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程 ginka = GinkaModel().to(device) - minamo = MinamoModel().to(device) + minamo = MinamoModel2().to(device) dataset = GinkaWGANDataset(args.train, device) dataset_val = GinkaWGANDataset(args.validate, device) @@ -113,7 +114,7 @@ def train(): dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) - optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9)) + optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9)) # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs) # scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs) @@ -201,14 +202,24 @@ def train(): fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) + fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) + + if train_stage == 4: + loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, masked2, x_in, tag_cond, val_cond) loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond) loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond) loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond) - dis_avg = (dis1 + dis2 + dis3) / 3.0 - loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0 + dis = [dis1, dis2, dis3] + loss_d = [loss_d1, loss_d2, loss_d3] + + if train_stage == 4: + dis.append(dis0) + loss_d.append(loss_d0) + + dis_avg = sum(dis) / len(dis) + loss_d_avg = sum(loss_d) / len(loss_d) # 反向传播 loss_d_avg.backward() @@ -230,7 +241,7 @@ def train(): loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond) loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond) - loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0 + loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0 loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3) loss_g.backward() @@ -240,19 +251,16 @@ def train(): elif train_stage == 3 or train_stage == 4: fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4) - - if train_stage == 3: - loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1, tag_cond, val_cond) - else: - loss_g1 = criterion.generator_loss_total(minamo, 1, fake1, tag_cond, val_cond) + + loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, x_in, tag_cond, val_cond) loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond) loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond) if train_stage == 4: - loss_head = criterion.generator_input_head_loss(x_in) + loss_head = criterion.generator_input_head_loss(minamo, x_in, tag_cond, val_cond) loss_head.backward(retain_graph=True) - loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0 + loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0 loss_g.backward() optimizer_ginka.step() loss_total_ginka += loss_g.detach() @@ -286,6 +294,8 @@ def train(): }, f"result/wgan/minamo-{epoch + 1}.pth") idx = 0 + gap = 5 + color = (255, 255, 255) # 白色 with torch.no_grad(): for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): real1 = batch["real1"].to(device) @@ -301,17 +311,42 @@ def train(): fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) + fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) + x_in = torch.argmax(x_in, dim=1).cpu().numpy() fake1 = torch.argmax(fake1, dim=1).cpu().numpy() fake2 = torch.argmax(fake2, dim=1).cpu().numpy() fake3 = torch.argmax(fake3, dim=1).cpu().numpy() - + masked1 = torch.argmax(masked1, dim=1).cpu().numpy() + masked2 = torch.argmax(masked2, dim=1).cpu().numpy() + masked3 = torch.argmax(masked3, dim=1).cpu().numpy() + for i in range(fake1.shape[0]): - for key, one in enumerate([fake1, fake2, fake3]): - map_matrix = one[i] - image = matrix_to_image_cv(map_matrix, tile_dict) - cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image) + fake1_img = matrix_to_image_cv(fake1[i], tile_dict) + fake2_img = matrix_to_image_cv(fake2[i], tile_dict) + fake3_img = matrix_to_image_cv(fake3[i], tile_dict) + if train_stage == 1 or train_stage == 2: + vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 + hline = np.full((gap, 3 * 416 + gap * 2, 3), color, dtype=np.uint8) # 水平分割线 + in1_img = matrix_to_image_cv(masked1[i], tile_dict) + in2_img = matrix_to_image_cv(masked2[i], tile_dict) + in3_img = matrix_to_image_cv(masked3[i], tile_dict) + img = np.block([ + [[in1_img], [vline], [in2_img], [vline], [in3_img]], + [[hline]], + [[fake1_img], [vline], [fake2_img], [vline], [fake3_img]] + ]) + elif train_stage == 3 or train_stage == 4: + vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 + hline = np.full((gap, 2 * 416 + gap, 3), color, dtype=np.uint8) # 水平分割线 + in_img = matrix_to_image_cv(x_in[i], tile_dict) + img = np.block([ + [[in_img], [vline], [fake1_img]], + [[hline]], + [[fake2_img], [vline], [fake3_img]] + ]) + + cv2.imwrite(f"result/ginka_img/{idx}.png", img) idx += 1