perf: model fine-tuning

This commit is contained in:
unanmed 2025-05-11 23:50:08 +08:00
parent 21b693ec21
commit fa48863946
11 changed files with 393 additions and 241 deletions

View File

@@ -19,11 +19,11 @@ class DoubleConvBlock(nn.Module):
self.cnn = nn.Sequential(
nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(feats[1]),
nn.ELU(),
nn.GELU(),
nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(feats[2]),
nn.ELU(),
nn.GELU(),
)
def forward(self, x):
@@ -57,11 +57,11 @@ class GCNBlock(nn.Module):
# GCN forward
x = self.conv1(x, edge_index)
x = F.elu(self.norm1(x))
x = F.gelu(self.norm1(x))
x = self.conv2(x, edge_index)
x = F.elu(self.norm2(x))
x = F.gelu(self.norm2(x))
x = self.conv3(x, edge_index)
x = F.elu(self.norm3(x))
x = F.gelu(self.norm3(x))
# Reshape back to [B, C, H, W]
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
@@ -92,9 +92,9 @@ class TransformerGCNBlock(nn.Module):
# GCN forward
x = self.conv1(x, edge_index)
x = F.elu(self.norm1(x))
x = F.gelu(self.norm1(x))
x = self.conv2(x, edge_index)
x = F.elu(self.norm2(x))
x = F.gelu(self.norm2(x))
# Reshape back to [B, C, H, W]
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
@@ -104,8 +104,8 @@ class ConvFusionModule(nn.Module):
def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
super().__init__()
self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch])
self.gcn = GCNBlock(in_ch, hidden_ch, in_ch, w, h)
self.fusion = DoubleConvBlock([in_ch*2, hidden_ch*2, out_ch])
self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h)
self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch])
def forward(self, x):
x1 = self.cnn(x)
@@ -120,11 +120,11 @@ class DoubleFCModule(nn.Module):
self.fc = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ELU(),
nn.GELU(),
nn.Linear(hidden_dim, out_dim),
nn.LayerNorm(out_dim),
nn.ELU()
nn.GELU()
)
def forward(self, x):

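A quick shape check for the reworked blocks above (a minimal sketch, not part of the commit; it assumes the ConvFusionModule defined in this file is in scope): both the CNN and the GCN branch preserve [B, C, H, W], so the fusion block sees in_ch*2 channels and maps them to out_ch.

import torch

# Hypothetical smoke test: both branches keep the spatial size,
# so their channel concatenation feeds DoubleConvBlock([in_ch*2, hidden_ch, out_ch]).
m = ConvFusionModule(in_ch=32, hidden_ch=64, out_ch=32, w=13, h=13)
x = torch.randn(2, 32, 13, 13)  # [B, C, H, W]
y = m(x)
assert y.shape == (2, 32, 13, 13)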
View File

@@ -6,22 +6,22 @@ from .common import DoubleFCModule
class ConditionEncoder(nn.Module):
def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
super().__init__()
self.tag_embed = DoubleFCModule(tag_dim, hidden_dim*2, hidden_dim)
self.val_embed = DoubleFCModule(val_dim, hidden_dim*2, hidden_dim)
self.stage_embed = DoubleFCModule(1, hidden_dim*2, hidden_dim)
self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim)
self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim)
self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim)
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
batch_first=True
),
num_layers=6
num_layers=4
)
self.fusion = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim*2),
nn.LayerNorm(hidden_dim*2),
nn.ELU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim*2, out_dim)
nn.Linear(hidden_dim, out_dim)
)
def forward(self, tag, val, stage):
@@ -38,18 +38,10 @@ class ConditionInjector(nn.Module):
def __init__(self, cond_dim, out_dim):
super().__init__()
self.gamma_layer = nn.Sequential(
nn.Linear(cond_dim, cond_dim*2),
nn.LayerNorm(cond_dim*2),
nn.ELU(),
nn.Linear(cond_dim*2, out_dim)
nn.Linear(cond_dim, out_dim)
)
self.beta_layer = nn.Sequential(
nn.Linear(cond_dim, cond_dim*2),
nn.LayerNorm(cond_dim*2),
nn.ELU(),
nn.Linear(cond_dim*2, out_dim)
nn.Linear(cond_dim, out_dim)
)
def forward(self, x, cond):

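The slimmed-down ConditionInjector is FiLM-style conditioning: a single linear layer per branch now predicts a per-channel scale (gamma) and shift (beta) from the condition vector. A minimal self-contained sketch of the same idea (class name hypothetical):

import torch
import torch.nn as nn

class FiLM(nn.Module):
    def __init__(self, cond_dim: int, num_channels: int):
        super().__init__()
        self.gamma = nn.Linear(cond_dim, num_channels)
        self.beta = nn.Linear(cond_dim, num_channels)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: [B, C, H, W], cond: [B, cond_dim]
        g = self.gamma(cond).unsqueeze(-1).unsqueeze(-1)  # [B, C, 1, 1]
        b = self.beta(cond).unsqueeze(-1).unsqueeze(-1)
        return x * g + b  # per-channel affine modulation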
View File

@@ -2,22 +2,138 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_geometric.nn import global_max_pool, GCNConv
from torch_geometric.nn import global_max_pool, GCNConv, TransformerConv
from torch_geometric.utils import grid
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from shared.graph import batch_convert_soft_map_to_graph
from .vision import MinamoVisionModel
from .topo import MinamoTopoModel
from ..common.cond import ConditionEncoder
def print_memory(tag=""):
print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
def batch_edge_index(B, edge_index, num_nodes_per_batch):
# Offset edge_index for each sample in the batch
edge_index = edge_index.clone() # [2, E]
batch_edge_index = []
for i in range(B):
offset = i * num_nodes_per_batch
batch_edge_index.append(edge_index + offset)
return torch.cat(batch_edge_index, dim=1)
class DoubleConvBlock(nn.Module):
def __init__(self, feats: tuple[int, int, int]):
super().__init__()
self.cnn = nn.Sequential(
spectral_norm(nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate')),
nn.GELU(),
spectral_norm(nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate')),
nn.GELU(),
)
def forward(self, x):
x = self.cnn(x)
return x
class TransformerGCNBlock(nn.Module):
def __init__(self, in_ch, hidden_ch, out_ch, w, h):
super().__init__()
self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True)
self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1)
self.single_edge_index, _ = grid(h, w) # [2, E] for a single map
def forward(self, x):
B, C, H, W = x.shape
x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
device = x.device
edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W)
x = self.conv1(x, edge_index)
x = F.gelu(x)
x = self.conv2(x, edge_index)
x = F.gelu(x)
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
return x
class ConvFusionModule(nn.Module):
def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
super().__init__()
self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch])
self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h)
self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch])
def forward(self, x):
x1 = self.cnn(x)
x2 = self.gcn(x)
x = torch.cat([x1, x2], dim=1)
x = self.fusion(x)
return x
class DoubleFCModule(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_dim, hidden_dim)),
nn.GELU(),
spectral_norm(nn.Linear(hidden_dim, out_dim)),
nn.GELU()
)
def forward(self, x):
x = self.fc(x)
return x
class ConditionEncoder(nn.Module):
def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
super().__init__()
self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim)
self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim)
self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim)
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
batch_first=True
),
num_layers=4
)
self.fusion = nn.Sequential(
spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
nn.GELU(),
spectral_norm(nn.Linear(hidden_dim, out_dim))
)
def forward(self, tag, val, stage):
tag = self.tag_embed(tag)
val = self.val_embed(val)
stage = self.stage_embed(stage)
feat = torch.stack([tag, val, stage], dim=1)
feat = self.encoder(feat)
feat = torch.mean(feat, dim=1)
feat = self.fusion(feat)
return feat
class ConditionInjector(nn.Module):
def __init__(self, cond_dim, out_dim):
super().__init__()
self.gamma_layer = nn.Sequential(
spectral_norm(nn.Linear(cond_dim, out_dim))
)
self.beta_layer = nn.Sequential(
spectral_norm(nn.Linear(cond_dim, out_dim))
)
def forward(self, x, cond):
gamma = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3)
beta = self.beta_layer(cond).unsqueeze(2).unsqueeze(3)
return x * gamma + beta
class CNNHead(nn.Module):
def __init__(self, in_ch):
super().__init__()
self.cnn = nn.Sequential(
spectral_norm(nn.Conv2d(in_ch, in_ch, 3)),
nn.LeakyReLU(0.2),
nn.GELU(),
nn.AdaptiveMaxPool2d((2, 2))
)
@@ -46,7 +162,7 @@ class GCNHead(nn.Module):
def forward(self, x, graph, cond):
x = self.gcn(x, graph.edge_index)
x = F.leaky_relu(x, 0.2)
x = F.gelu(x)
x = global_max_pool(x, graph.batch)
cond = self.proj(cond)
proj = torch.sum(x * cond, dim=1, keepdim=True)
@@ -91,6 +207,65 @@ class MinamoModel(nn.Module):
raise RuntimeError("Unknown critic stage.")
score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
return score, vision_score, topo_score
class MinamoHead2(nn.Module):
def __init__(self, in_ch, hidden_ch):
super().__init__()
self.conv = ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13)
self.pool = nn.AdaptiveMaxPool2d(1)
self.proj = spectral_norm(nn.Linear(256, hidden_ch))
self.fc = spectral_norm(nn.Linear(hidden_ch, 1))
def forward(self, x, cond):
x = self.conv(x)
x = self.pool(x)
x = x.squeeze(3).squeeze(2)
cond = self.proj(cond)
proj = torch.sum(x * cond, dim=1, keepdim=True)
x = self.fc(x) + proj
return x
class MinamoModel2(nn.Module):
def __init__(self, tile_types=32):
super().__init__()
self.cond = ConditionEncoder(64, 16, 256, 256)
self.conv1 = ConvFusionModule(tile_types, 256, 128, 13, 13)
self.conv2 = ConvFusionModule(128, 256, 256, 13, 13)
self.conv3 = ConvFusionModule(256, 512, 256, 13, 13)
self.head0 = MinamoHead2(256, 256) # critic head for the random input head
self.head1 = MinamoHead2(256, 256)
self.head2 = MinamoHead2(256, 256)
self.head3 = MinamoHead2(256, 256)
self.inject1 = ConditionInjector(256, 128)
self.inject2 = ConditionInjector(256, 256)
self.inject3 = ConditionInjector(256, 256)
def forward(self, x, stage, tag_cond, val_cond):
B, D = tag_cond.shape
stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
cond = self.cond(tag_cond, val_cond, stage_tensor)
x = self.conv1(x)
x = self.inject1(x, cond)
x = self.conv2(x)
x = self.inject2(x, cond)
x = self.conv3(x)
x = self.inject3(x, cond)
if stage == 0:
score = self.head0(x, cond)
elif stage == 1:
score = self.head1(x, cond)
elif stage == 2:
score = self.head2(x, cond)
elif stage == 3:
score = self.head3(x, cond)
else:
raise RuntimeError("Unknown critic stage.")
return score
# Check VRAM usage
if __name__ == "__main__":
@@ -99,19 +274,19 @@ if __name__ == "__main__":
val = torch.rand(1, 16).cuda()
# Initialize the model
model = MinamoModel().cuda()
model = MinamoModel2().cuda()
print_memory("after init")
# Forward pass
output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1, tag, val)
output = model(input, 1, tag, val)
print_memory("after forward")
print(f"Input shape: feat={input.shape}")
print(f"Output shape: output={output.shape}")
# print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
# print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

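For reference, the batch_edge_index helper defined earlier in this file builds one disjoint graph for the whole batch by offsetting node indices: node k of sample i becomes k + i * H * W. A vectorized equivalent (a sketch, not part of the commit):

import torch

def batch_edge_index_vectorized(B, edge_index, num_nodes_per_batch):
    # edge_index: [2, E] for one map; returns [2, B * E] with per-sample offsets.
    offsets = torch.arange(B, device=edge_index.device) * num_nodes_per_batch
    # [2, E, 1] + [1, 1, B] broadcasts to [2, E, B]; flatten sample-major.
    out = edge_index.unsqueeze(-1) + offsets.view(1, 1, B)
    return out.permute(0, 2, 1).reshape(2, -1)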
View File

@@ -2,12 +2,12 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_geometric.nn import GATConv
from torch_geometric.nn import GATConv, TransformerConv
from torch_geometric.data import Data
class MinamoTopoModel(nn.Module):
def __init__(
self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512
self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512
):
super().__init__()
# Input is softmax probabilities; project them directly
@@ -16,9 +16,9 @@ class MinamoTopoModel(nn.Module):
nn.LeakyReLU(0.2)
)
# Graph convolution layers
self.conv1 = GATConv(emb_dim, hidden_dim, heads=8)
self.conv2 = GATConv(hidden_dim*8, hidden_dim, heads=8)
self.conv3 = GATConv(hidden_dim*8, out_dim, heads=1)
self.conv1 = TransformerConv(emb_dim, hidden_dim, heads=8)
self.conv2 = TransformerConv(hidden_dim*8, hidden_dim, heads=8)
self.conv3 = TransformerConv(hidden_dim*8, out_dim, heads=1)
def forward(self, graph: Data):
x = self.input_proj(graph.x)

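One bookkeeping detail behind the GATConv → TransformerConv swap above: with the default concat=True, a layer with heads=8 outputs hidden_dim*8 channels, which is why conv2 and conv3 take hidden_dim*8 as input, and the last layer uses heads=1 to land on out_dim. A shape sketch under those assumptions:

import torch
from torch_geometric.nn import TransformerConv
from torch_geometric.utils import grid

conv1 = TransformerConv(128, 128, heads=8)      # concat=True -> 128 * 8 channels out
conv2 = TransformerConv(128 * 8, 128, heads=8)  # -> 128 * 8 channels out
conv3 = TransformerConv(128 * 8, 512, heads=1)  # -> 512 channels out

edge_index, _ = grid(13, 13)   # grid graph of a 13x13 map
x = torch.randn(13 * 13, 128)  # one node feature per tile
x = conv3(conv2(conv1(x, edge_index), edge_index), edge_index)
assert x.shape == (13 * 13, 512)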
View File

@@ -10,13 +10,10 @@ class MinamoVisionModel(nn.Module):
spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11
nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), # 9*9
spectral_norm(nn.Conv2d(in_ch*2, in_ch*8, 3)), # 9*9
nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7
nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*8, out_ch, 3)), # 5*5
spectral_norm(nn.Conv2d(in_ch*8, out_ch, 3)), # 7*7
nn.LeakyReLU(0.2),
)

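The updated size comments follow from unpadded 3x3 convolutions shrinking each spatial dimension by 2; with one layer removed the stack now goes 13 → 11 → 9 → 7. A tiny sketch of that arithmetic (channels held constant for simplicity):

import torch
import torch.nn as nn

x = torch.randn(1, 32, 13, 13)
for _ in range(3):             # three unpadded 3x3 convs: 13 -> 11 -> 9 -> 7
    x = nn.Conv2d(32, 32, 3)(x)
assert x.shape[-2:] == (7, 7)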
View File

@@ -142,13 +142,14 @@ class GinkaWGANDataset(Dataset):
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
_, masked = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, 0.5)
rand = torch.rand(32, 32, 32, device=target.device)
return {
"real1": removed1,
"masked1": rand,
"real2": removed2,
"masked2": torch.zeros_like(target),
"masked2": masked,
"real3": removed3,
"masked3": torch.zeros_like(target),
"tag_cond": tag_cond,

View File

@@ -2,24 +2,25 @@ import torch
import torch.nn as nn
from ..common.common import ConvFusionModule
from ..common.cond import ConditionInjector
from .unet import GinkaEncoderPath, GinkaDecoderPath
class RandomInputHead(nn.Module):
def __init__(self):
super().__init__()
self.enc = ConvFusionModule(32, 256, 256, 32, 32)
self.enc = GinkaEncoderPath(32, 32)
self.dec = GinkaDecoderPath(32)
self.out_conv = nn.Sequential(
nn.Conv2d(256, 128, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(128),
nn.ELU(),
nn.AdaptiveMaxPool2d((15, 15)),
nn.Conv2d(32, 64, 3, padding=0),
nn.InstanceNorm2d(64),
nn.GELU(),
nn.AdaptiveMaxPool2d((13, 13)),
nn.Conv2d(128, 32, 1),
nn.Conv2d(64, 32, 1),
)
self.inject = ConditionInjector(256, 256)
def forward(self, x, cond):
x = self.enc(x)
x = self.inject(x, cond)
x1, x2, x3, x4 = self.enc(x, cond)
x = self.dec(x1, x2, x3, x4, cond)
x = self.out_conv(x)
return x
@@ -28,15 +29,12 @@ class InputUpsample(nn.Module):
super().__init__()
self.net = nn.Sequential(
ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13),
nn.ELU(),
nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26
ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, 26, 26),
nn.ELU(),
nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32
ConvFusionModule(hidden_ch, hidden_ch, out_ch, 32, 32),
nn.ELU(),
)
def forward(self, x): # [B, C, 13, 13]
@@ -47,18 +45,14 @@ class GinkaInput(nn.Module):
def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
super().__init__()
self.out_size = out_size
self.enc1 = ConvFusionModule(in_ch, in_ch*4, in_ch, in_size[0], in_size[1])
self.upsample = InputUpsample(in_ch, in_ch*2, out_ch)
self.enc2 = ConvFusionModule(out_ch, out_ch*4, out_ch, out_size[0], out_size[1])
self.inject1 = ConditionInjector(256, in_ch)
self.enc = ConvFusionModule(out_ch, out_ch*2, out_ch, out_size[0], out_size[1])
self.inject1 = ConditionInjector(256, out_ch)
self.inject2 = ConditionInjector(256, out_ch)
self.inject3 = ConditionInjector(256, out_ch)
def forward(self, x, cond):
x = self.enc1(x)
x = self.inject1(x, cond)
x = self.upsample(x)
x = self.inject1(x, cond)
x = self.enc(x)
x = self.inject2(x, cond)
x = self.enc2(x)
x = self.inject3(x, cond)
return x

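InputUpsample above grows the 13x13 map to the 32x32 working resolution in two nearest-neighbor steps, 13 → 26 then 26 → 32, with a ConvFusionModule at each scale. A shape sketch of just the upsampling path (not part of the commit):

import torch
import torch.nn as nn

up2 = nn.Upsample(scale_factor=2, mode='nearest')  # 13x13 -> 26x26
to32 = nn.Upsample(size=(32, 32), mode='nearest')  # 26x26 -> 32x32
x = torch.randn(1, 64, 13, 13)
assert up2(x).shape == (1, 64, 26, 26)
assert to32(up2(x)).shape == (1, 64, 32, 32)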
View File

@@ -1,12 +1,7 @@
import math
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from shared.graph import batch_convert_soft_map_to_graph
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from ..critic.model import MinamoModel
CLASS_NUM = 32
ILLEGAL_MAX_NUM = 30
@@ -156,15 +151,15 @@ def entrance_constraint_loss(
)
return total_loss
def input_head_illegal_loss(input_map, allowed_classes=(0, 1)):
def input_head_illegal_loss(input_map, allowed_classes=[0, 1, 2]):
C = input_map.shape[1]
mask = torch.ones(C, device=input_map.device)
mask[list(allowed_classes)] = 0 # zero out allowed classes; all others stay 1
illegal_class_penalty = (input_map * mask.view(1, -1, 1, 1)).sum() / input_map.numel()
return illegal_class_penalty
unallowed = get_not_allowed(allowed_classes, include_illegal=True)
illegal = input_map[:, unallowed, :, :]
penalty = torch.sum(illegal)
def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1):
return penalty
def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=[1, 2]):
wall_prob = input_map[:, wall_class] # [B, H, W]
wall_ratio = wall_prob.mean() # mean wall occupancy ratio
wall_penalty = torch.clamp(wall_ratio - max_wall_ratio, min=0.0) # penalize only the excess
@@ -241,6 +236,16 @@ def immutable_penalty_loss(
return penalty
def modifiable_penalty_loss(
probs: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
) -> torch.Tensor:
target_modifiable = input[:, modifiable_classes, :, :]
pred_modifiable = probs[:, modifiable_classes, :, :]
existed = torch.clamp(target_modifiable - pred_modifiable, min=0.0, max=1.0)
penalty = F.mse_loss(existed, torch.zeros_like(existed, device=existed.device))
return penalty
def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
not_allowed = get_not_allowed(legal_classes, include_illegal=True)
input_mask = pred[:, not_allowed, :, :]
@@ -249,43 +254,40 @@ def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
return penalty
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.05, 0.5]):
# weight: critic loss, CE loss, immutable-class and illegal-tile loss, tile-type loss, entrance-existence loss, diversity loss, density loss
def __init__(self, lambda_gp=100, weight=[1, 0.4, 50, 0.2, 0.2, 0.05, 0.4]):
# weight:
# 1. critic loss plus tile-preservation loss (existing tiles in modifiable regions must be kept)
# 2. CE loss
# 3. immutable-class loss and illegal-tile loss
# 4. tile-type loss
# 5. entrance-existence loss
# 6. diversity loss
# 7. density loss
self.lambda_gp = lambda_gp # gradient penalty coefficient
self.weight = weight
def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond):
# Interpolate between real and fake data
batch_size = real_data.size(0)
epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device)
epsilon_data = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device)
interp_graph = batch_convert_soft_map_to_graph(interp_data).to(real_data.device)
# Backprop through the interpolated input and compute gradients
interp_data.requires_grad_()
interp_graph.x.requires_grad_()
_, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage, tag_cond, val_cond)
d_score = critic(interp_data, stage, tag_cond, val_cond)
# Compute gradients
grad_vis = torch.autograd.grad(
outputs=d_vis_score, inputs=interp_data,
grad_outputs=torch.ones_like(d_vis_score),
create_graph=True, retain_graph=True, only_inputs=True
)[0]
grad_topo = torch.autograd.grad(
outputs=d_topo_score, inputs=interp_graph.x,
grad_outputs=torch.ones_like(d_topo_score),
grad = torch.autograd.grad(
outputs=d_score, inputs=interp_data,
grad_outputs=torch.ones_like(d_score),
create_graph=True, retain_graph=True, only_inputs=True
)[0]
# L2 norm of the gradients
grad_norm_vis = grad_vis.view(batch_size, -1).norm(2, dim=1)
grad_norm_topo = grad_topo.view(batch_size, -1).norm(2, dim=1)
grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1)
# Gradient penalty term
gp_loss_vis = ((grad_norm_vis - 1.0) ** 2).mean()
gp_loss_topo = ((grad_norm_topo - 1.0) ** 2).mean()
gp_loss = gp_loss_vis * VISION_WEIGHT + gp_loss_topo * TOPO_WEIGHT
gp_loss = ((grad_norm - 1.0) ** 2).mean()
# print(grad_norm_topo.mean().item(), grad_norm_vis.mean().item())
return gp_loss
@@ -296,10 +298,8 @@ class WGANGinkaLoss:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Critic loss function."""
fake_data = F.softmax(fake_data, dim=1)
real_graph = batch_convert_soft_map_to_graph(real_data)
fake_graph = batch_convert_soft_map_to_graph(fake_data)
real_scores, _, _ = critic(real_data, real_graph, stage, tag_cond, val_cond)
fake_scores, _, _ = critic(fake_data, fake_graph, stage, tag_cond, val_cond)
real_scores = critic(real_data, stage, tag_cond, val_cond)
fake_scores = critic(fake_data, stage, tag_cond, val_cond)
# Wasserstein distance
d_loss = fake_scores.mean() - real_scores.mean()
@@ -312,10 +312,9 @@ class WGANGinkaLoss:
def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generator loss function."""
probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores)
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # the larger the mask, the lower the CE weight
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
constraint_loss = inner_constraint_loss(probs_fake)
@@ -343,9 +342,8 @@ class WGANGinkaLoss:
def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor:
probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores)
illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
constraint_loss = inner_constraint_loss(probs_fake)
@@ -370,10 +368,9 @@ class WGANGinkaLoss:
def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor:
probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores)
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
@@ -395,13 +392,15 @@ class WGANGinkaLoss:
return sum(losses)
def generator_input_head_loss(self, probs: torch.Tensor) -> torch.Tensor:
def generator_input_head_loss(self, critic, map: torch.Tensor, tag_cond, val_cond) -> torch.Tensor:
probs = F.softmax(map, dim=1)
head_scores = critic(probs, 0, tag_cond, val_cond)
probs_a, probs_b = probs.chunk(2, dim=0)
losses = [
torch.mean(head_scores),
input_head_illegal_loss(probs),
input_head_wall_loss(probs),
-js_divergence(probs_a, probs_b, softmax=False) * 0.3
-js_divergence(probs_a, probs_b, softmax=False) * 0.1
]
return sum(losses)

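The rewritten penalty follows the standard WGAN-GP recipe against a single critic score, and the interpolation coefficient now comes from torch.rand (epsilon ~ U[0,1], as WGAN-GP prescribes) rather than torch.randn. A self-contained sketch of the core computation (the critic signature here is hypothetical):

import torch

def gradient_penalty(critic, real, fake, lambda_gp=100.0):
    # Interpolate with epsilon ~ U[0, 1], one draw per sample.
    B = real.size(0)
    eps = torch.rand(B, 1, 1, 1, device=real.device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    score = critic(interp)
    grad = torch.autograd.grad(
        outputs=score, inputs=interp,
        grad_outputs=torch.ones_like(score),
        create_graph=True, retain_graph=True, only_inputs=True,
    )[0]
    # Push the gradient's L2 norm toward 1 (1-Lipschitz constraint).
    grad_norm = grad.reshape(B, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1.0) ** 2).mean()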
View File

@@ -1,22 +1,15 @@
import torch
import torch.nn as nn
from ..common.common import GCNBlock, DoubleConvBlock
from ..common.common import ConvFusionModule
from ..common.cond import ConditionInjector
class StageHead(nn.Module):
def __init__(self, in_ch, out_ch, out_size=(13, 13)):
super().__init__()
self.cnn_head = DoubleConvBlock([in_ch, in_ch*2, in_ch])
self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
self.dec = ConvFusionModule(in_ch, in_ch*2, in_ch, 32, 32)
self.pool = nn.Sequential(
nn.Conv2d(in_ch, in_ch*2, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(in_ch*2),
nn.ELU(),
nn.Conv2d(in_ch*2, in_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(in_ch),
nn.ELU(),
ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32),
ConvFusionModule(in_ch*2, in_ch*2, in_ch, 32, 32),
nn.AdaptiveMaxPool2d(out_size),
nn.Conv2d(in_ch, out_ch, 1)
@@ -24,10 +17,7 @@ class StageHead(nn.Module):
self.inject = ConditionInjector(256, in_ch)
def forward(self, x, cond):
x_cnn = self.cnn_head(x)
x_gcn = self.gcn_head(x)
x = torch.cat([x_cnn, x_gcn], dim=1)
x = self.fusion(x)
x = self.dec(x)
x = self.inject(x, cond)
x = self.pool(x)
return x

View File

@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from shared.attention import ChannelAttention
from ..common.common import GCNBlock, TransformerGCNBlock
from ..common.common import GCNBlock, TransformerGCNBlock, DoubleConvBlock, ConvFusionModule
from ..common.cond import ConditionInjector
class GinkaTransformerEncoder(nn.Module):
@@ -37,16 +37,17 @@ class GinkaTransformerEncoder(nn.Module):
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, attn=True):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
nn.ELU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
)
if attn:
self.conv.append(ChannelAttention(out_ch))
self.conv.append(nn.ELU())
self.conv = DoubleConvBlock([in_ch, out_ch, out_ch])
# self.conv = nn.Sequential(
# nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
# nn.InstanceNorm2d(out_ch),
# nn.ELU(),
# nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
# nn.InstanceNorm2d(out_ch),
# )
# if attn:
# self.conv.append(ChannelAttention(out_ch))
# self.conv.append(nn.ELU())
def forward(self, x):
return self.conv(x)
@@ -64,47 +65,24 @@ class FusionModule(nn.Module):
class GinkaUNetInput(nn.Module):
def __init__(self, in_ch, out_ch, w, h):
super().__init__()
self.conv = ConvBlock(in_ch, in_ch)
self.gcn = TransformerGCNBlock(in_ch, in_ch*2, in_ch, w, h)
self.fusion = ConvBlock(in_ch*2, out_ch)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, cond):
x1 = self.conv(x)
x2 = self.gcn(x)
x = torch.cat([x1, x2], dim=1)
x = self.fusion(x)
x = self.inject(x, cond)
return x
class GinkaEncoder(nn.Module):
"""编码器(下采样)部分"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = ConvBlock(in_ch, out_ch)
self.pool = nn.MaxPool2d(2)
self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, cond):
x = self.conv(x)
x = self.pool(x)
x = self.inject(x, cond)
return x
class GinkaGCNFusedEncoder(nn.Module):
class GinkaEncoder(nn.Module):
def __init__(self, in_ch, out_ch, w, h):
super().__init__()
self.conv = ConvBlock(in_ch, out_ch)
self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
self.pool = nn.MaxPool2d(2)
self.fusion = FusionModule(out_ch*2, out_ch)
self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, cond):
x = self.conv(x)
x = self.pool(x)
x2 = self.gcn(x)
x = self.fusion(x, x2)
x = self.conv(x)
x = self.inject(x, cond)
return x
@@ -114,42 +92,29 @@ class GinkaUpSample(nn.Module):
self.conv = nn.Sequential(
nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
nn.InstanceNorm2d(out_ch),
nn.ELU(),
nn.GELU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
nn.GELU()
)
def forward(self, x):
return self.conv(x)
class GinkaDecoder(nn.Module):
"""解码器(上采样)部分"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
self.conv = ConvBlock(in_ch, out_ch)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, feat, cond):
x = self.upsample(x)
x = torch.cat([x, feat], dim=1)
x = self.conv(x)
x = self.inject(x, cond)
return x
class GinkaGCNFusedDecoder(nn.Module):
def __init__(self, in_ch, out_ch, w, h):
super().__init__()
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
self.conv = ConvBlock(in_ch, out_ch)
self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
self.fusion = FusionModule(out_ch*2, out_ch)
self.fusion = nn.Conv2d(in_ch, in_ch, 1)
self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, feat, cond):
x = self.upsample(x)
x = torch.cat([x, feat], dim=1)
x = self.fusion(x)
x = self.conv(x)
x2 = self.gcn(x)
x = self.fusion(x, x2)
x = self.inject(x, cond)
return x
@@ -162,58 +127,62 @@ class GinkaBottleneck(nn.Module):
# )
# self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
# self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
self.conv = ConvBlock(module_ch, module_ch)
self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, w, h)
self.fusion = nn.Conv2d(module_ch*2, module_ch, 1)
self.conv = ConvFusionModule(module_ch, module_ch, module_ch, w, h)
self.inject = ConditionInjector(256, module_ch)
def forward(self, x, cond):
B = x.size(0)
# x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
# x1 = self.transformer(x1)
# x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
x1 = self.conv(x)
x2 = self.gcn(x)
x = torch.cat([x1, x2], dim=1)
x = self.fusion(x)
x = self.conv(x)
x = self.inject(x, cond)
return x
class GinkaUNet(nn.Module):
def __init__(self, in_ch=32, base_ch=64, out_ch=32):
"""Ginka model U-Net component."""
class GinkaEncoderPath(nn.Module):
def __init__(self, in_ch, base_ch):
super().__init__()
self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32)
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4)
self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8)
self.up2 = GinkaGCNFusedDecoder(base_ch*4, base_ch*2, 16, 16)
self.up3 = GinkaGCNFusedDecoder(base_ch*2, base_ch, 32, 32)
self.final = nn.Sequential(
nn.Conv2d(base_ch, out_ch, 1),
nn.InstanceNorm2d(out_ch),
nn.ELU(),
)
self.down2 = GinkaEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaEncoder(base_ch*2, base_ch*4, 8, 8)
self.down4 = GinkaEncoder(base_ch*4, base_ch*8, 4, 4)
def forward(self, x, cond):
x1 = self.down1(x, cond) # [B, 64, 32, 32]
x2 = self.down2(x1, cond) # [B, 128, 16, 16]
x3 = self.down3(x2, cond) # [B, 256, 8, 8]
x4 = self.down4(x3, cond) # [B, 512, 4, 4]
x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4]
# Upsampling
return x1, x2, x3, x4
class GinkaDecoderPath(nn.Module):
def __init__(self, base_ch):
super().__init__()
self.up1 = GinkaDecoder(base_ch*8, base_ch*4, 8, 8)
self.up2 = GinkaDecoder(base_ch*4, base_ch*2, 16, 16)
self.up3 = GinkaDecoder(base_ch*2, base_ch, 32, 32)
def forward(self, x1, x2, x3, x4, cond):
x = self.up1(x4, x3, cond) # [B, 256, 8, 8]
x = self.up2(x, x2, cond) # [B, 128, 16, 16]
x = self.up3(x, x1, cond) # [B, 64, 32, 32]
return x
class GinkaUNet(nn.Module):
def __init__(self, in_ch=32, base_ch=32, out_ch=32):
"""Ginka model U-Net component."""
super().__init__()
self.enc = GinkaEncoderPath(in_ch, base_ch)
self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
self.dec = GinkaDecoderPath(base_ch)
self.final = ConvFusionModule(base_ch, base_ch, out_ch, 32, 32)
def forward(self, x, cond):
x1, x2, x3, x4 = self.enc(x, cond)
x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4]
x = self.dec(x1, x2, x3, x4, cond)
x = self.final(x) # [B, 32, 32, 32]
return x

View File

@@ -6,12 +6,13 @@ import torch
import torch.optim as optim
import torch.nn.functional as F
import cv2
import numpy as np
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .generator.model import GinkaModel
from .dataset import GinkaWGANDataset
from .generator.loss import WGANGinkaLoss
from .critic.model import MinamoModel
from .critic.model import MinamoModel2
from shared.image import matrix_to_image_cv
# Tag definitions:
@@ -105,7 +106,7 @@ def train():
stage_epoch = 0 # epochs elapsed in the current stage, used to control the training schedule
ginka = GinkaModel().to(device)
minamo = MinamoModel().to(device)
minamo = MinamoModel2().to(device)
dataset = GinkaWGANDataset(args.train, device)
dataset_val = GinkaWGANDataset(args.validate, device)
@@ -113,7 +114,7 @@ def train():
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9))
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
@@ -201,14 +202,24 @@ def train():
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
if train_stage == 4:
loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, masked2, x_in, tag_cond, val_cond)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond)
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond)
dis_avg = (dis1 + dis2 + dis3) / 3.0
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
dis = [dis1, dis2, dis3]
loss_d = [loss_d1, loss_d2, loss_d3]
if train_stage == 4:
dis.append(dis0)
loss_d.append(loss_d0)
dis_avg = sum(dis) / len(dis)
loss_d_avg = sum(loss_d) / len(loss_d)
# Backward pass
loss_d_avg.backward()
@@ -230,7 +241,7 @@ def train():
loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0
loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
loss_g.backward()
@@ -240,19 +251,16 @@ def train():
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
if train_stage == 3:
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1, tag_cond, val_cond)
else:
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1, tag_cond, val_cond)
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, x_in, tag_cond, val_cond)
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond)
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)
if train_stage == 4:
loss_head = criterion.generator_input_head_loss(x_in)
loss_head = criterion.generator_input_head_loss(minamo, x_in, tag_cond, val_cond)
loss_head.backward(retain_graph=True)
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0
loss_g.backward()
optimizer_ginka.step()
loss_total_ginka += loss_g.detach()
@@ -286,6 +294,8 @@ def train():
}, f"result/wgan/minamo-{epoch + 1}.pth")
idx = 0
gap = 5
color = (255, 255, 255) # white
with torch.no_grad():
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
real1 = batch["real1"].to(device)
@@ -301,17 +311,42 @@ def train():
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
x_in = torch.argmax(x_in, dim=1).cpu().numpy()
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
masked1 = torch.argmax(masked1, dim=1).cpu().numpy()
masked2 = torch.argmax(masked2, dim=1).cpu().numpy()
masked3 = torch.argmax(masked3, dim=1).cpu().numpy()
for i in range(fake1.shape[0]):
for key, one in enumerate([fake1, fake2, fake3]):
map_matrix = one[i]
image = matrix_to_image_cv(map_matrix, tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
fake1_img = matrix_to_image_cv(fake1[i], tile_dict)
fake2_img = matrix_to_image_cv(fake2[i], tile_dict)
fake3_img = matrix_to_image_cv(fake3[i], tile_dict)
if train_stage == 1 or train_stage == 2:
vline = np.full((416, gap, 3), color, dtype=np.uint8) # vertical divider
hline = np.full((gap, 3 * 416 + gap * 2, 3), color, dtype=np.uint8) # horizontal divider
in1_img = matrix_to_image_cv(masked1[i], tile_dict)
in2_img = matrix_to_image_cv(masked2[i], tile_dict)
in3_img = matrix_to_image_cv(masked3[i], tile_dict)
img = np.block([
[[in1_img], [vline], [in2_img], [vline], [in3_img]],
[[hline]],
[[fake1_img], [vline], [fake2_img], [vline], [fake3_img]]
])
elif train_stage == 3 or train_stage == 4:
vline = np.full((416, gap, 3), color, dtype=np.uint8) # vertical divider
hline = np.full((gap, 2 * 416 + gap, 3), color, dtype=np.uint8) # horizontal divider
in_img = matrix_to_image_cv(x_in[i], tile_dict)
img = np.block([
[[in_img], [vline], [fake1_img]],
[[hline]],
[[fake2_img], [vline], [fake3_img]]
])
cv2.imwrite(f"result/ginka_img/{idx}.png", img)
idx += 1
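The validation grids above are assembled with np.block. Note the extra nesting level in those calls: for 3-D (H, W, C) images, np.block concatenates the innermost lists along the last axis (channels), so each tile is wrapped in its own list to push the width and height concatenation one level up. A minimal sketch with the same tile size and gap (not part of the commit):

import numpy as np

gap, size, color = 5, 416, (255, 255, 255)
tiles = [np.zeros((size, size, 3), dtype=np.uint8) for _ in range(4)]
vline = np.full((size, gap, 3), color, dtype=np.uint8)
hline = np.full((gap, 2 * size + gap, 3), color, dtype=np.uint8)
img = np.block([
    [[tiles[0]], [vline], [tiles[1]]],  # innermost lists: no-op along channels
    [[hline]],                          # middle level: concat along width
    [[tiles[2]], [vline], [tiles[3]]],  # outer level: concat along height
])
assert img.shape == (2 * size + gap, 2 * size + gap, 3)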