From 55f09fb37b4ae9a9ac2551348edb134ed4f03f5d Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 7 May 2025 15:38:31 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E6=94=B9=E8=BF=9B=E8=BE=93=E5=85=A5?= =?UTF-8?q?=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/common/common.py | 34 +++++- ginka/common/cond.py | 13 +-- ginka/generator/input.py | 62 ++++++----- ginka/generator/loss.py | 220 +++++++++------------------------------ ginka/generator/model.py | 8 +- ginka/generator/unet.py | 22 +++- ginka/train_wgan.py | 9 +- 7 files changed, 148 insertions(+), 220 deletions(-) diff --git a/ginka/common/common.py b/ginka/common/common.py index 0121583..873a289 100644 --- a/ginka/common/common.py +++ b/ginka/common/common.py @@ -61,4 +61,36 @@ class GCNBlock(nn.Module): for i in range(B): offset = i * num_nodes_per_batch batch_edge_index.append(edge_index + offset) - return torch.cat(batch_edge_index, dim=1) \ No newline at end of file + return torch.cat(batch_edge_index, dim=1) + +class ConvFusionModule(nn.Module): + def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int): + super().__init__() + self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch]) + self.gcn = GCNBlock(in_ch, hidden_ch, in_ch, w, h) + self.fusion = DoubleConvBlock([in_ch*2, hidden_ch*2, out_ch]) + + def forward(self, x): + x1 = self.cnn(x) + x2 = self.gcn(x) + x = torch.cat([x1, x2], dim=1) + x = self.fusion(x) + return x + +class DoubleFCModule(nn.Module): + def __init__(self, in_dim, hidden_dim, out_dim): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(in_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ELU(), + + nn.Linear(hidden_dim, out_dim), + nn.LayerNorm(out_dim), + nn.ELU() + ) + + def forward(self, x): + x = self.fc(x) + return x + \ No newline at end of file diff --git a/ginka/common/cond.py b/ginka/common/cond.py index bd6c218..bc972b6 100644 --- a/ginka/common/cond.py +++ b/ginka/common/cond.py @@ -1,19 +1,14 @@ import torch import torch.nn as nn import torch.nn.functional as F +from .common import DoubleFCModule class ConditionEncoder(nn.Module): def __init__(self, tag_dim, val_dim, hidden_dim, out_dim): super().__init__() - self.tag_embed = nn.Linear(tag_dim, hidden_dim) - self.val_embed = nn.Linear(val_dim, hidden_dim) - self.stage_embed = nn.Sequential( - nn.Linear(1, 64), - nn.LayerNorm(64), - nn.ELU(), - - nn.Linear(64, hidden_dim), - ) + self.tag_embed = DoubleFCModule(tag_dim, hidden_dim*2, hidden_dim) + self.val_embed = DoubleFCModule(val_dim, hidden_dim*2, hidden_dim) + self.stage_embed = DoubleFCModule(1, hidden_dim*2, hidden_dim) self.encoder = nn.TransformerEncoder( nn.TransformerEncoderLayer( d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4, diff --git a/ginka/generator/input.py b/ginka/generator/input.py index da0a37b..3275bc6 100644 --- a/ginka/generator/input.py +++ b/ginka/generator/input.py @@ -1,18 +1,12 @@ import torch import torch.nn as nn -from ..common.common import GCNBlock, DoubleConvBlock +from ..common.common import ConvFusionModule from ..common.cond import ConditionInjector class RandomInputHead(nn.Module): def __init__(self): super().__init__() - self.conv = DoubleConvBlock([32, 64, 128]) - self.gcn = GCNBlock(32, 128, 128, 32, 32) - self.fusion = nn.Sequential( - nn.Conv2d(256, 256, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(256), - nn.ELU(), - ) + self.enc = ConvFusionModule(32, 256, 256, 32, 32) self.out_conv = nn.Sequential( nn.Conv2d(256, 128, 3, 
padding=1, padding_mode='replicate'), nn.InstanceNorm2d(128), @@ -24,33 +18,45 @@ class RandomInputHead(nn.Module): self.inject = ConditionInjector(256, 256) def forward(self, x, cond): - x_cnn = self.conv(x) - x_gcn = self.gcn(x) - x = torch.cat([x_cnn, x_gcn], dim=1) - x = self.fusion(x) + x = self.enc(x) x = self.inject(x, cond) x = self.out_conv(x) return x + +class InputUpsample(nn.Module): + def __init__(self, in_ch, hidden_ch=64, out_ch=64): + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1), + nn.ELU(), + + nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26 + nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1), + nn.ELU(), + + nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32 + nn.Conv2d(hidden_ch, out_ch, kernel_size=3, padding=1), + nn.ELU(), + ) + + def forward(self, x): # [B, C, 13, 13] + x = self.net(x) # [B, C, 32, 32] + return x class GinkaInput(nn.Module): def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)): super().__init__() self.out_size = out_size - self.fc = nn.Sequential( - nn.Linear(in_size[0] * in_size[1], out_size[0] * out_size[1]), - nn.LayerNorm(out_size[0] * out_size[1]), - nn.ELU() - ) - self.conv = nn.Sequential( - nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(out_ch), - nn.ELU() - ) + self.enc1 = ConvFusionModule(in_ch, in_ch*4, in_ch, in_size[0], in_size[1]) + self.upsample = InputUpsample(in_ch, in_ch*2, out_ch) + self.enc2 = ConvFusionModule(out_ch, out_ch*4, out_ch, out_size[0], out_size[1]) + self.inject1 = ConditionInjector(256, in_ch) + self.inject2 = ConditionInjector(256, out_ch) - def forward(self, x): - B, C, H, W = x.shape - x = x.view(B, C, H * W) - x = self.fc(x) - x = x.view(B, C, self.out_size[0], self.out_size[1]) - x = self.conv(x) + def forward(self, x, cond): + x = self.enc1(x) + x = self.inject1(x, cond) + x = self.upsample(x) + x = self.enc2(x) + x = self.inject2(x, cond) return x diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index 2d13ef9..91046bd 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -11,13 +11,20 @@ from ..critic.model import MinamoModel CLASS_NUM = 32 ILLEGAL_MAX_NUM = 30 -STAGE_ALLOWED = [ +STAGE_CHANGEABLE = [ [], [0, 1, 2, 29, 30], [3, 4, 5, 6, 26, 27, 28], list(range(7, 26)) ] +STAGE_ALLOWED = [ + [], + STAGE_CHANGEABLE[1], + [*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2]], + [*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2], *STAGE_CHANGEABLE[3]] +] + DENSITY_MAP = [ [1, *list(range(3, 30))], [1], @@ -32,6 +39,27 @@ DENSITY_MAP = [ [29, 30] ] +DENSITY_WEIGHTS = [ + 1, + 1.5, + 0.5, + 5, + 4, + 3, + 3, + 3, + 5, + 10, + 20 +] + +DENSITY_STAGE = [ + [], + [1, 2, 10], + [1, 2, 3, 4, 10], + list(range(0, 11)) +] + def get_not_allowed(classes: list[int], include_illegal=False): res = list() for num in range(0, CLASS_NUM): @@ -44,37 +72,6 @@ def get_not_allowed(classes: list[int], include_illegal=False): return res -def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[*list(range(0, 29)), 30]): - """ - 强制地图最外圈像素必须为指定类别(墙或箭头) - - 参数: - pred: 模型输出的概率分布,形状 [B, C, H, W] - allowed_classes: 允许出现在外圈的类别列表 - - 返回: - loss: 标量损失值 - """ - B, C, H, W = pred.shape - - # 创建外圈mask [H, W] - border_mask = torch.zeros((H, W), dtype=torch.bool, device=pred.device) - border_mask[0, :] = True # 第一行 - border_mask[-1, :] = True # 最后一行 - border_mask[:, 0] = True # 第一列 - border_mask[:, -1] = True # 最后一列 - - # 提取所有允许和不允许类别的概率和 [B, H, W] - unallowed_probs = 
pred[:, get_not_allowed(allowed_classes, include_illegal=True), :, :].sum(dim=1) - - # 获取外圈区域允许类别的概率 [B, N_pixels] - border_unallowed = unallowed_probs[:, border_mask] - - target = torch.zeros_like(border_unallowed) - loss_unallowed = F.mse_loss(border_unallowed, target) - - return loss_unallowed - def inner_constraint_loss(pred: torch.Tensor, allowed=list(range(0, 30))): """限定内部允许出现的图块种类 @@ -159,93 +156,6 @@ def entrance_constraint_loss( ) return total_loss -def adaptive_count_loss( - pred_probs: torch.Tensor, - target_map: torch.Tensor, - class_list: list = list(range(32)), - margin_ratio: float = 0.1, # 降低margin比例以更严格 - zero_margin_scale: float = 0.1, # 减少零类别的margin - lambda_entropy: float = 0.2, # 增大熵约束权重 - lambda_local: float = 0.2, - lambda_max: float = 0, # 新增最大概率约束 - grid_size: int = 4, # 减小局部网格尺寸 - eps: float = 1e-3 -) -> torch.Tensor: - """ - 改进版自适应图块数量约束损失,增强局部匹配和概率确定性 - """ - B, C, H, W = pred_probs.shape - device = pred_probs.device - total_loss = 0.0 - valid_classes = 0 - - # 预计算地图面积 - map_area = math.sqrt(H * W) - - # 动态调整零类别的margin:基于预测中最小的非零概率 - min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean() - dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area - - # 计算每个类别的数量损失 - for cls in class_list: - pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测数量 - true_count = target_map[:, cls].sum(dim=(1,2)) # 真实数量 - - zero_mask = (true_count == 0) - dynamic_margin = torch.where( - zero_mask, - dynamic_zero_margin, - margin_ratio * true_count - ) - - safe_true = true_count + eps * zero_mask - abs_error = torch.abs(pred_count - true_count) - rel_error = abs_error / safe_true - - # 调整损失函数形状,远离目标时惩罚更大 - loss_per_class = torch.where( - abs_error <= dynamic_margin, - rel_error ** 2, # 近目标时二次损失 - (rel_error - 0.5 * margin_ratio) ** 2 # 远目标时二次增长 - ) - - # 零类别使用更严格的绝对误差惩罚 - loss_per_class = torch.where( - zero_mask, - F.relu(abs_error - dynamic_zero_margin) ** 2 / map_area, - loss_per_class - ) - - total_loss += loss_per_class.mean() - valid_classes += 1 - - total_loss /= valid_classes # 平均类别损失 - - # 改进的熵约束:每个像素的熵 - def entropy_loss(pred_probs): - entropy_per_pixel = -torch.sum(pred_probs * torch.log(pred_probs + 1e-6), dim=1) - return entropy_per_pixel.mean() # 所有像素的平均熵 - - total_loss += lambda_entropy * entropy_loss(pred_probs) - - # 新增最大概率约束:鼓励每个位置概率尖锐化 - max_probs = pred_probs.max(dim=1)[0] # 每个位置的最大概率 - max_loss = (1 - max_probs).mean() # 鼓励接近1 - total_loss += lambda_max * max_loss - - # 改进局部损失:约束局部区域内的数量 - def local_count_loss(pred_probs, target_probs, grid_size): - grid_area = grid_size ** 2 - # 计算每个grid内的预测数量 - pred_counts = F.avg_pool2d(pred_probs, grid_size, stride=grid_size) * grid_area - target_counts = F.avg_pool2d(target_probs, grid_size, stride=grid_size) * grid_area - # 使用L1损失更鲁棒 - return F.l1_loss(pred_counts, target_counts) - - total_loss += lambda_local * local_count_loss(pred_probs, target_map, grid_size) - - return total_loss - def input_head_illegal_loss(input_map, allowed_classes=(0, 1)): C = input_map.shape[1] mask = torch.ones(C, device=input_map.device) @@ -261,7 +171,7 @@ def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1): return wall_penalty -def compute_multi_density_loss(probs, target_densities): +def compute_multi_density_loss(probs, target_densities, tile_list): """ pred: [B, C, H, W] target_densities: [B, N] - N 个目标类别密度 @@ -271,53 +181,10 @@ def compute_multi_density_loss(probs, target_densities): for i, classes in enumerate(DENSITY_MAP): class_map = probs[:, classes, :, :] pred_density = 
torch.mean(class_map, dim=(1, 2, 3)) - loss = F.mse_loss(pred_density, target_densities[:, i]) - losses.append(loss) + if i in tile_list: + loss = F.mse_loss(pred_density, target_densities[:, i]) + losses.append(loss * DENSITY_WEIGHTS[i]) return sum(losses) - -class GinkaLoss(nn.Module): - def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]): - """Ginka Model 损失函数部分 - - Args: - weight (list, optional): 每一个损失函数的权重,从第 0 项开始,依次是: - 1. Minamo 相似度损失 - 2. 图块种类损失,要求内部不出现箭头,外圈只出现箭头和墙壁,不允许出现非法图块 - 3. 入口间距及存在性损失 - 4. 怪物、道具、门数量损失 - """ - super().__init__() - self.weight = weight - self.minamo = minamo - - def forward(self, pred, target, target_vision_feat, target_topo_feat): - # 地图结构损失 - class_loss = outer_border_constraint_loss(pred) + inner_constraint_loss(pred) - entrance_loss = entrance_constraint_loss(pred) - count_loss = adaptive_count_loss(pred, target) - - # 使用 Minamo Model 计算相似度 - graph = batch_convert_soft_map_to_graph(pred) - pred_vision_feat, pred_topo_feat = self.minamo(pred, graph) - - vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=1) - topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=1) - minamo_sim = 0 * vision_sim + 1 * topo_sim - # tqdm.write(f"{vision_sim.mean().item():.12f}, {topo_sim.mean().item():.12f}") - minamo_loss = (1.0 - minamo_sim).mean() - - tqdm.write( - f"{minamo_loss.item():.12f}, {class_loss.item():.12f}, {entrance_loss.item():.12f}, {count_loss.item():.12f}" - ) - - losses = [ - minamo_loss * self.weight[0] * 4, - class_loss * self.weight[1], - entrance_loss * self.weight[2], - count_loss * self.weight[3] - ] - - return sum(losses) # 对图像数据进行插值 def interpolate_data(real_data, fake_data, epsilon): @@ -374,9 +241,16 @@ def immutable_penalty_loss( return penalty +def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]): + not_allowed = get_not_allowed(legal_classes, include_illegal=True) + input_mask = pred[:, not_allowed, :, :] + target = torch.zeros_like(input_mask) + penalty = F.cross_entropy(input_mask, target) + return penalty + class WGANGinkaLoss: def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]): - # weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失 + # weight: 判别器损失,CE 损失,不可修改类型损失和非法图块损失,图块类型损失,入口存在性损失,多样性损失,密度损失 self.lambda_gp = lambda_gp # 梯度惩罚系数 self.weight = weight @@ -443,9 +317,9 @@ class WGANGinkaLoss: fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) minamo_loss = -torch.mean(fake_scores) ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小 - immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) + immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage]) constraint_loss = inner_constraint_loss(probs_fake) - density_loss = compute_multi_density_loss(probs_fake, val_cond) + density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) fake_a, fake_b = fake.chunk(2, dim=0) @@ -473,13 +347,15 @@ class WGANGinkaLoss: fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) minamo_loss = -torch.mean(fake_scores) + illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage]) constraint_loss = inner_constraint_loss(probs_fake) - density_loss = compute_multi_density_loss(probs_fake, val_cond) + density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) fake_a, fake_b = fake.chunk(2, dim=0) losses = [ minamo_loss * self.weight[0], + illegal_loss * self.weight[2], constraint_loss * 
self.weight[3], -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], density_loss * self.weight[6], @@ -498,9 +374,9 @@ class WGANGinkaLoss: fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) minamo_loss = -torch.mean(fake_scores) - immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) + immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage]) constraint_loss = inner_constraint_loss(probs_fake) - density_loss = compute_multi_density_loss(probs_fake, val_cond) + density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) fake_a, fake_b = fake.chunk(2, dim=0) diff --git a/ginka/generator/model.py b/ginka/generator/model.py index acb1826..30991a1 100644 --- a/ginka/generator/model.py +++ b/ginka/generator/model.py @@ -16,8 +16,8 @@ class GinkaModel(nn.Module): super().__init__() self.head = RandomInputHead() self.cond = ConditionEncoder(64, 16, 256, 256) - self.input = GinkaInput(32, 32, (13, 13), (32, 32)) - self.unet = GinkaUNet(32, base_ch, base_ch) + self.input = GinkaInput(32, 64, (13, 13), (32, 32)) + self.unet = GinkaUNet(64, base_ch, base_ch) self.output = GinkaOutput(base_ch, out_ch, (13, 13)) def forward(self, x, stage, tag_cond, val_cond, random=False): @@ -28,7 +28,7 @@ class GinkaModel(nn.Module): x_in = F.softmax(self.head(x, cond), dim=1) else: x_in = x - x = self.input(x_in) + x = self.input(x_in, cond) x = self.unet(x, cond) x = self.output(x, stage, cond) return x, x_in @@ -51,7 +51,7 @@ if __name__ == "__main__": print(f"输入形状: feat={input.shape}") print(f"输出形状: output={output.shape}") - print(f"Head parameters: {sum(p.numel() for p in model.head.parameters())}") + print(f"Random parameters: {sum(p.numel() for p in model.head.parameters())}") print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}") print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}") diff --git a/ginka/generator/unet.py b/ginka/generator/unet.py index f7bea1d..5fd2647 100644 --- a/ginka/generator/unet.py +++ b/ginka/generator/unet.py @@ -60,6 +60,22 @@ class FusionModule(nn.Module): x = torch.cat([x1, x2], dim=1) x = self.conv(x) return x + +class GinkaUNetInput(nn.Module): + def __init__(self, in_ch, out_ch, w, h): + super().__init__() + self.conv = ConvBlock(in_ch, in_ch) + self.gcn = GCNBlock(in_ch, in_ch*2, in_ch, w, h) + self.fusion = ConvBlock(in_ch*2, out_ch) + self.inject = ConditionInjector(256, out_ch) + + def forward(self, x, cond): + x1 = self.conv(x) + x2 = self.gcn(x) + x = torch.cat([x1, x2], dim=1) + x = self.fusion(x) + x = self.inject(x, cond) + return x class GinkaEncoder(nn.Module): """编码器(下采样)部分""" @@ -142,7 +158,7 @@ class GinkaBottleneck(nn.Module): super().__init__() self.transformer = GinkaTransformerEncoder( in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h, - token_size=16, ff_dim=1024, num_layers=4 + token_size=16, ff_dim=1024, num_layers=6 ) self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4) self.fusion = nn.Conv2d(module_ch*3, module_ch, 1) @@ -167,7 +183,7 @@ class GinkaUNet(nn.Module): """Ginka Model UNet 部分 """ super().__init__() - self.down1 = ConvBlock(in_ch, base_ch) + self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32) self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16) self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8) self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4) @@ -184,7 +200,7 
@@ class GinkaUNet(nn.Module): ) def forward(self, x, cond): - x1 = self.down1(x) # [B, 64, 32, 32] + x1 = self.down1(x, cond) # [B, 64, 32, 32] x2 = self.down2(x1, cond) # [B, 128, 16, 16] x3 = self.down3(x2, cond) # [B, 256, 8, 8] x4 = self.down4(x3, cond) # [B, 512, 4, 4] diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 14e1e80..2a6e375 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -325,11 +325,11 @@ def train(): low_loss_epochs = 0 if train_stage >= 2: - if stage_epoch % 5 == 1: + if (epoch + 1) % 5 == 1: train_stage = 3 - elif stage_epoch % 5 == 3: + elif (epoch + 1) % 5 == 3: train_stage = 4 - elif stage_epoch % 5 == 0: + elif (epoch + 1) % 5 == 0: train_stage = 2 if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch: @@ -350,6 +350,9 @@ def train(): else: g_steps = 1 + if avg_loss_ginka > 0: + g_steps += int(max(avg_loss_ginka * 5, 0)) + if avg_loss_minamo > 0: c_steps = int(min(5 + avg_loss_minamo * 5, 15)) else:
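
A standalone sketch of the reworked input path in ginka/generator/input.py: the patch drops the Linear-based 13x13 -> 32x32 resize in GinkaInput and instead routes the map through InputUpsample, which grows it with two nearest-neighbour upsamples (13 -> 26 via scale_factor=2, then 26 -> 32 via a fixed target size) with 3x3 conv + ELU blocks in between. The module body below mirrors the patch; the imports, the random input batch, and the shape assertion are illustrative additions for a quick shape check, not repository code.

import torch
import torch.nn as nn

class InputUpsample(nn.Module):
    """Conv/ELU blocks around two nearest-neighbour upsamples: 13x13 -> 26x26 -> 32x32."""
    def __init__(self, in_ch, hidden_ch=64, out_ch=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
            nn.ELU(),
            nn.Upsample(scale_factor=2, mode='nearest'),   # 13x13 -> 26x26
            nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
            nn.ELU(),
            nn.Upsample(size=(32, 32), mode='nearest'),     # 26x26 -> 32x32
            nn.Conv2d(hidden_ch, out_ch, kernel_size=3, padding=1),
            nn.ELU(),
        )

    def forward(self, x):       # x: [B, in_ch, 13, 13]
        return self.net(x)      # [B, out_ch, 32, 32]

if __name__ == "__main__":
    up = InputUpsample(in_ch=32, hidden_ch=64, out_ch=64)  # matches GinkaInput(32, 64, ...)
    x = torch.randn(2, 32, 13, 13)                         # placeholder batch, not real data
    assert up(x).shape == (2, 64, 32, 32)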
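
compute_multi_density_loss now takes a tile_list argument: only the density groups unlocked at the current training stage (DENSITY_STAGE[stage]) contribute, and each contributing group is scaled by its DENSITY_WEIGHTS entry. The sketch below restates that gating outside the repository; the function follows the patched logic, but the small density_map/weights tables and the random tensors are placeholder assumptions (the full DENSITY_MAP is only partially visible in this hunk).

import torch
import torch.nn.functional as F

def compute_multi_density_loss(probs, target_densities, tile_list, density_map, weights):
    # probs: [B, C, H, W] softmax map; target_densities: [B, len(density_map)]
    losses = []
    for i, classes in enumerate(density_map):
        if i not in tile_list:
            continue  # group not unlocked at this stage -> no density constraint
        pred_density = probs[:, classes, :, :].mean(dim=(1, 2, 3))  # mean occupancy of the group
        losses.append(F.mse_loss(pred_density, target_densities[:, i]) * weights[i])
    return sum(losses)

# Hypothetical three-group layout; the stage gate unlocks groups 0 and 1 only.
density_map = [[1], [3, 4], [29, 30]]
weights = [1.0, 1.5, 0.5]
probs = torch.softmax(torch.randn(2, 32, 32, 32), dim=1)
targets = torch.rand(2, 3)
loss = compute_multi_density_loss(probs, targets, tile_list=[0, 1],
                                  density_map=density_map, weights=weights)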
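
One observation on the new illegal_penalty_loss: it is called with probs_fake (already softmaxed) and pairs F.cross_entropy with an all-zero soft-label target. Soft-target cross entropy is -sum(target * log_softmax(input)), so with every target entry equal to 0 the term evaluates to exactly zero and contributes nothing to the gradient. If the intent matches the removed outer_border_constraint_loss (push the probability mass of disallowed classes toward zero), an MSE-style penalty is closer. A minimal sketch under that assumption; illegal_mass_penalty and the example index list are hypothetical, not repository code.

import torch
import torch.nn.functional as F

def illegal_mass_penalty(probs, not_allowed):
    # probs: [B, C, H, W] softmax map; not_allowed: indices of disallowed tile classes
    disallowed = probs[:, not_allowed, :, :]
    return F.mse_loss(disallowed, torch.zeros_like(disallowed))

probs = torch.softmax(torch.randn(2, 32, 32, 32), dim=1)
penalty = illegal_mass_penalty(probs, not_allowed=[30, 31])  # hypothetical disallowed indices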
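
The stage rotation in train_wgan.py now keys off the global epoch counter instead of stage_epoch. A tiny illustrative loop, assuming train_stage has already reached 2 and no other condition overrides it, showing the resulting cycle:

for epoch in range(10):
    if (epoch + 1) % 5 == 1:
        train_stage = 3      # epochs 1, 6, ...
    elif (epoch + 1) % 5 == 3:
        train_stage = 4      # epochs 3, 8, ...
    elif (epoch + 1) % 5 == 0:
        train_stage = 2      # epochs 5, 10, ...
    # other epochs keep the previous stage
    print(epoch + 1, train_stage)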