perf: 模型微调

This commit is contained in:
unanmed 2025-05-11 23:50:08 +08:00
parent 21b693ec21
commit fa48863946
11 changed files with 393 additions and 241 deletions

View File

@ -19,11 +19,11 @@ class DoubleConvBlock(nn.Module):
self.cnn = nn.Sequential( self.cnn = nn.Sequential(
nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate'), nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(feats[1]), nn.InstanceNorm2d(feats[1]),
nn.ELU(), nn.GELU(),
nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate'), nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(feats[2]), nn.InstanceNorm2d(feats[2]),
nn.ELU(), nn.GELU(),
) )
def forward(self, x): def forward(self, x):
@ -57,11 +57,11 @@ class GCNBlock(nn.Module):
# GCN forward # GCN forward
x = self.conv1(x, edge_index) x = self.conv1(x, edge_index)
x = F.elu(self.norm1(x)) x = F.gelu(self.norm1(x))
x = self.conv2(x, edge_index) x = self.conv2(x, edge_index)
x = F.elu(self.norm2(x)) x = F.gelu(self.norm2(x))
x = self.conv3(x, edge_index) x = self.conv3(x, edge_index)
x = F.elu(self.norm3(x)) x = F.gelu(self.norm3(x))
# Reshape back to [B, C, H, W] # Reshape back to [B, C, H, W]
x = x.view(B, H, W, -1).permute(0, 3, 1, 2) x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
@ -92,9 +92,9 @@ class TransformerGCNBlock(nn.Module):
# GCN forward # GCN forward
x = self.conv1(x, edge_index) x = self.conv1(x, edge_index)
x = F.elu(self.norm1(x)) x = F.gelu(self.norm1(x))
x = self.conv2(x, edge_index) x = self.conv2(x, edge_index)
x = F.elu(self.norm2(x)) x = F.gelu(self.norm2(x))
# Reshape back to [B, C, H, W] # Reshape back to [B, C, H, W]
x = x.view(B, H, W, -1).permute(0, 3, 1, 2) x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
@ -104,8 +104,8 @@ class ConvFusionModule(nn.Module):
def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int): def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
super().__init__() super().__init__()
self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch]) self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch])
self.gcn = GCNBlock(in_ch, hidden_ch, in_ch, w, h) self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h)
self.fusion = DoubleConvBlock([in_ch*2, hidden_ch*2, out_ch]) self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch])
def forward(self, x): def forward(self, x):
x1 = self.cnn(x) x1 = self.cnn(x)
@ -120,11 +120,11 @@ class DoubleFCModule(nn.Module):
self.fc = nn.Sequential( self.fc = nn.Sequential(
nn.Linear(in_dim, hidden_dim), nn.Linear(in_dim, hidden_dim),
nn.LayerNorm(hidden_dim), nn.LayerNorm(hidden_dim),
nn.ELU(), nn.GELU(),
nn.Linear(hidden_dim, out_dim), nn.Linear(hidden_dim, out_dim),
nn.LayerNorm(out_dim), nn.LayerNorm(out_dim),
nn.ELU() nn.GELU()
) )
def forward(self, x): def forward(self, x):

View File

@ -6,22 +6,22 @@ from .common import DoubleFCModule
class ConditionEncoder(nn.Module): class ConditionEncoder(nn.Module):
def __init__(self, tag_dim, val_dim, hidden_dim, out_dim): def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
super().__init__() super().__init__()
self.tag_embed = DoubleFCModule(tag_dim, hidden_dim*2, hidden_dim) self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim)
self.val_embed = DoubleFCModule(val_dim, hidden_dim*2, hidden_dim) self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim)
self.stage_embed = DoubleFCModule(1, hidden_dim*2, hidden_dim) self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim)
self.encoder = nn.TransformerEncoder( self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer( nn.TransformerEncoderLayer(
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4, d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
batch_first=True batch_first=True
), ),
num_layers=6 num_layers=4
) )
self.fusion = nn.Sequential( self.fusion = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim*2), nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim*2), nn.LayerNorm(hidden_dim),
nn.ELU(), nn.GELU(),
nn.Linear(hidden_dim*2, out_dim) nn.Linear(hidden_dim, out_dim)
) )
def forward(self, tag, val, stage): def forward(self, tag, val, stage):
@ -38,18 +38,10 @@ class ConditionInjector(nn.Module):
def __init__(self, cond_dim, out_dim): def __init__(self, cond_dim, out_dim):
super().__init__() super().__init__()
self.gamma_layer = nn.Sequential( self.gamma_layer = nn.Sequential(
nn.Linear(cond_dim, cond_dim*2), nn.Linear(cond_dim, out_dim)
nn.LayerNorm(cond_dim*2),
nn.ELU(),
nn.Linear(cond_dim*2, out_dim)
) )
self.beta_layer = nn.Sequential( self.beta_layer = nn.Sequential(
nn.Linear(cond_dim, cond_dim*2), nn.Linear(cond_dim, out_dim)
nn.LayerNorm(cond_dim*2),
nn.ELU(),
nn.Linear(cond_dim*2, out_dim)
) )
def forward(self, x, cond): def forward(self, x, cond):

View File

@ -2,22 +2,138 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.utils import spectral_norm from torch.nn.utils import spectral_norm
from torch_geometric.nn import global_max_pool, GCNConv from torch_geometric.nn import global_max_pool, GCNConv, TransformerConv
from torch_geometric.utils import grid
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from shared.graph import batch_convert_soft_map_to_graph
from .vision import MinamoVisionModel from .vision import MinamoVisionModel
from .topo import MinamoTopoModel from .topo import MinamoTopoModel
from ..common.cond import ConditionEncoder
def print_memory(tag=""): def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
def batch_edge_index(B, edge_index, num_nodes_per_batch):
    """Tile a single-graph edge index over a batch of equally sized graphs.

    Copy ``i`` of the edge index has ``i * num_nodes_per_batch`` added to
    every node id, so the copies address disjoint node ranges of one flat
    node list.

    Args:
        B: Number of graphs in the batch.
        edge_index: Tensor of shape [2, E] holding the edges of one graph.
        num_nodes_per_batch: Number of nodes in each graph.

    Returns:
        A new tensor of shape [2, B * E] with the same dtype and device as
        ``edge_index``; columns ``i * E`` to ``(i + 1) * E - 1`` hold copy
        ``i``. ``edge_index`` itself is not modified. For ``B == 0`` an empty
        [2, 0] tensor is returned.
    """
    num_edges = edge_index.size(1)
    # One offset per graph, computed on the same device/dtype as the edges.
    offsets = torch.arange(
        B, device=edge_index.device, dtype=edge_index.dtype
    ) * num_nodes_per_batch
    # [2, 1, E] + [1, B, 1] -> [2, B, E]; a single broadcast add replaces the
    # per-graph Python loop and the defensive clone.
    shifted = edge_index.unsqueeze(1) + offsets.view(1, B, 1)
    # Explicit size (not -1) so the reshape is well defined when B == 0.
    return shifted.reshape(2, B * num_edges)
class DoubleConvBlock(nn.Module):
    """Two spectrally normalised 3x3 convolutions, each followed by GELU.

    ``feats`` gives the (input, hidden, output) channel counts. Both
    convolutions use a replicate padding of 1, so the spatial size of the
    input is preserved.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        in_ch, mid_ch, out_ch = feats[0], feats[1], feats[2]
        first = spectral_norm(
            nn.Conv2d(in_ch, mid_ch, 3, padding=1, padding_mode='replicate')
        )
        second = spectral_norm(
            nn.Conv2d(mid_ch, out_ch, 3, padding=1, padding_mode='replicate')
        )
        self.cnn = nn.Sequential(first, nn.GELU(), second, nn.GELU())

    def forward(self, x):
        """Apply both conv + GELU stages to ``x`` and return the result."""
        return self.cnn(x)
class TransformerGCNBlock(nn.Module):
    """Two TransformerConv layers applied to a feature map viewed as a grid graph.

    Every pixel of a ``[B, C, H, W]`` input becomes one graph node; the edges
    come from ``torch_geometric.utils.grid(h, w)`` and are shared by all maps
    in the batch.

    Args:
        in_ch: Channels of the input feature map.
        hidden_ch: Width after the first layer; must be divisible by the
            8 attention heads, whose outputs are concatenated.
        out_ch: Channels of the returned feature map.
        w: Width of the feature maps this block will receive.
        h: Height of the feature maps this block will receive.

    Raises:
        ValueError: If ``hidden_ch`` is not divisible by 8.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        heads = 8
        if hidden_ch % heads != 0:
            # conv1 emits (hidden_ch // heads) * heads features while conv2
            # expects hidden_ch inputs; a remainder would only surface later
            # as a shape mismatch inside forward().
            raise ValueError(
                f"hidden_ch ({hidden_ch}) must be divisible by {heads}"
            )
        self.conv1 = TransformerConv(in_ch, hidden_ch // heads, heads=heads, concat=True)
        self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1)
        self.grid_size = (h, w)
        single_edge_index, _ = grid(h, w)  # [2, E] for a single map
        # Non-persistent buffer: it follows .to()/.cuda() with the module, but
        # stays out of state_dict so existing checkpoints load unchanged.
        self.register_buffer("single_edge_index", single_edge_index, persistent=False)

    def forward(self, x):
        """Run both graph layers on ``x`` ([B, C, H, W]) and return [B, out_ch, H, W].

        Raises:
            ValueError: If the spatial size of ``x`` differs from the ``(h, w)``
                the edge index was built for.
        """
        B, C, H, W = x.shape
        if (H, W) != self.grid_size:
            # The stored edges index nodes of an h x w grid; any other size
            # would wire the wrong pixels together or index out of range.
            raise ValueError(
                f"expected a {self.grid_size[0]}x{self.grid_size[1]} map, got {H}x{W}"
            )
        # Flatten pixels to nodes in (batch, row, column) order.
        nodes = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        # .to() is a no-op once the module lives on the same device as x.
        edge_index = batch_edge_index(
            B, self.single_edge_index.to(nodes.device), H * W
        )
        nodes = F.gelu(self.conv1(nodes, edge_index))
        nodes = F.gelu(self.conv2(nodes, edge_index))
        # Back to image layout [B, out_ch, H, W].
        return nodes.view(B, H, W, -1).permute(0, 3, 1, 2)
class ConvFusionModule(nn.Module):
    """Run a CNN branch and a graph branch in parallel and fuse them by convolution.

    Both branches map ``in_ch`` channels back to ``in_ch`` channels; their
    outputs are concatenated along the channel axis and reduced to ``out_ch``
    channels by a further DoubleConvBlock. ``w`` and ``h`` are forwarded to
    the graph branch.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
        super().__init__()
        self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch])
        self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h)
        # The fusion stage sees both branches stacked, hence twice in_ch.
        self.fusion = DoubleConvBlock([2 * in_ch, hidden_ch, out_ch])

    def forward(self, x):
        """Return the fused features of the convolutional and graph branches."""
        conv_feat = self.cnn(x)
        graph_feat = self.gcn(x)
        stacked = torch.cat([conv_feat, graph_feat], dim=1)
        return self.fusion(stacked)
class DoubleFCModule(nn.Module):
    """Two spectrally normalised linear layers, each followed by GELU.

    Maps the last dimension ``in_dim -> hidden_dim -> out_dim``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        first = spectral_norm(nn.Linear(in_dim, hidden_dim))
        second = spectral_norm(nn.Linear(hidden_dim, out_dim))
        self.fc = nn.Sequential(first, nn.GELU(), second, nn.GELU())

    def forward(self, x):
        """Apply both linear + GELU stages to ``x`` and return the result."""
        return self.fc(x)
class ConditionEncoder(nn.Module):
    """Encode tag, value and stage conditions into one conditioning vector.

    Each of the three inputs is embedded to ``hidden_dim`` by its own
    DoubleFCModule. The embeddings are treated as a three-token sequence,
    passed through a 4-layer Transformer encoder (8 heads), averaged over the
    tokens and projected to ``out_dim``.
    """

    def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
        super().__init__()
        self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim)
        self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim)
        # The stage arrives as a single scalar feature per sample.
        self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.fusion = nn.Sequential(
            spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim, out_dim))
        )

    def forward(self, tag, val, stage):
        """Return the fused condition vector for a batch of (tag, val, stage)."""
        tokens = torch.stack(
            [self.tag_embed(tag), self.val_embed(val), self.stage_embed(stage)],
            dim=1,
        )
        encoded = self.encoder(tokens)
        # Average over the three condition tokens.
        pooled = torch.mean(encoded, dim=1)
        return self.fusion(pooled)
class ConditionInjector(nn.Module):
    """Modulate a feature map with a per-channel scale and shift derived from a condition.

    ``cond`` (last dimension ``cond_dim``) is mapped by two independent
    spectrally normalised linear layers to ``out_dim`` scales (gamma) and
    ``out_dim`` shifts (beta), which are broadcast over the two trailing
    spatial axes of ``x``.
    """

    def __init__(self, cond_dim, out_dim):
        super().__init__()
        # Each layer stays wrapped in nn.Sequential so parameter names
        # keep their "<name>.0." prefix.
        gamma = spectral_norm(nn.Linear(cond_dim, out_dim))
        self.gamma_layer = nn.Sequential(gamma)
        beta = spectral_norm(nn.Linear(cond_dim, out_dim))
        self.beta_layer = nn.Sequential(beta)

    def forward(self, x, cond):
        """Return ``x * gamma(cond) + beta(cond)`` with gamma/beta broadcast spatially."""
        scale = self.gamma_layer(cond)[:, :, None, None]
        shift = self.beta_layer(cond)[:, :, None, None]
        return x * scale + shift
class CNNHead(nn.Module): class CNNHead(nn.Module):
def __init__(self, in_ch): def __init__(self, in_ch):
super().__init__() super().__init__()
self.cnn = nn.Sequential( self.cnn = nn.Sequential(
spectral_norm(nn.Conv2d(in_ch, in_ch, 3)), spectral_norm(nn.Conv2d(in_ch, in_ch, 3)),
nn.LeakyReLU(0.2), nn.GELU(),
nn.AdaptiveMaxPool2d((2, 2)) nn.AdaptiveMaxPool2d((2, 2))
) )
@ -46,7 +162,7 @@ class GCNHead(nn.Module):
def forward(self, x, graph, cond): def forward(self, x, graph, cond):
x = self.gcn(x, graph.edge_index) x = self.gcn(x, graph.edge_index)
x = F.leaky_relu(x, 0.2) x = F.gelu(x)
x = global_max_pool(x, graph.batch) x = global_max_pool(x, graph.batch)
cond = self.proj(cond) cond = self.proj(cond)
proj = torch.sum(x * cond, dim=1, keepdim=True) proj = torch.sum(x * cond, dim=1, keepdim=True)
@ -91,6 +207,65 @@ class MinamoModel(nn.Module):
raise RuntimeError("Unknown critic stage.") raise RuntimeError("Unknown critic stage.")
score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
return score, vision_score, topo_score return score, vision_score, topo_score
class MinamoHead2(nn.Module):
    """Critic head: fuse features, pool globally, and score with a condition term.

    The score is ``fc(feat) + sum(feat * proj(cond))``, where ``feat`` is the
    globally max-pooled output of a ConvFusionModule.

    Args:
        in_ch: Channels of the incoming feature map.
        hidden_ch: Channels produced by the fusion module and used for scoring.
        cond_dim: Width of the condition vector (previously fixed at 256).
        w: Width of the incoming feature map (previously fixed at 13).
        h: Height of the incoming feature map (previously fixed at 13).
    """

    def __init__(self, in_ch, hidden_ch, cond_dim=256, w=13, h=13):
        super().__init__()
        self.conv = ConvFusionModule(in_ch, hidden_ch, hidden_ch, w, h)
        self.pool = nn.AdaptiveMaxPool2d(1)
        self.proj = spectral_norm(nn.Linear(cond_dim, hidden_ch))
        self.fc = spectral_norm(nn.Linear(hidden_ch, 1))

    def forward(self, x, cond):
        """Return a ``[B, 1]`` score for feature map ``x`` under condition ``cond``."""
        feat = self.conv(x)
        # [B, C, 1, 1] -> [B, C]
        feat = self.pool(feat).flatten(1)
        cond_embed = self.proj(cond)
        # Inner product between pooled features and the embedded condition.
        cond_score = torch.sum(feat * cond_embed, dim=1, keepdim=True)
        return self.fc(feat) + cond_score
class MinamoModel2(nn.Module):
    """Conditional critic: three fusion stages with condition injection and one head per stage.

    The input map (``tile_types`` channels, 13x13 as fixed by the fusion
    modules) passes through three ConvFusionModules, each followed by a
    ConditionInjector fed with the encoded (tag, val, stage) condition. The
    head matching ``stage`` (0-3) produces the score.
    """

    def __init__(self, tile_types=32):
        super().__init__()
        self.cond = ConditionEncoder(64, 16, 256, 256)
        self.conv1 = ConvFusionModule(tile_types, 256, 128, 13, 13)
        self.conv2 = ConvFusionModule(128, 256, 256, 13, 13)
        self.conv3 = ConvFusionModule(256, 512, 256, 13, 13)

        self.head0 = MinamoHead2(256, 256)  # critic head for the random-input head
        self.head1 = MinamoHead2(256, 256)
        self.head2 = MinamoHead2(256, 256)
        self.head3 = MinamoHead2(256, 256)

        self.inject1 = ConditionInjector(256, 128)
        self.inject2 = ConditionInjector(256, 256)
        self.inject3 = ConditionInjector(256, 256)

    def forward(self, x, stage, tag_cond, val_cond):
        """Score ``x`` for the given stage and conditions.

        Args:
            x: Input map batch with ``tile_types`` channels.
            stage: Critic stage, one of 0, 1, 2, 3.
            tag_cond: Tag condition, first dimension is the batch size.
            val_cond: Value condition.

        Returns:
            The score tensor produced by the head selected by ``stage``.

        Raises:
            RuntimeError: If ``stage`` is not one of the four known stages.
        """
        # Pick the head first so an unknown stage fails before any of the
        # backbone has been evaluated.
        heads = (self.head0, self.head1, self.head2, self.head3)
        for index, candidate in enumerate(heads):
            if stage == index:
                head = candidate
                break
        else:
            raise RuntimeError("Unknown critic stage.")

        batch_size = tag_cond.shape[0]
        # Build the stage feature directly on the target device instead of
        # creating it on the CPU and copying it over on every call.
        stage_tensor = torch.full((batch_size, 1), float(stage), device=x.device)
        cond = self.cond(tag_cond, val_cond, stage_tensor)

        x = self.inject1(self.conv1(x), cond)
        x = self.inject2(self.conv2(x), cond)
        x = self.inject3(self.conv3(x), cond)

        return head(x, cond)
# 检查显存占用 # 检查显存占用
if __name__ == "__main__": if __name__ == "__main__":
@ -99,19 +274,19 @@ if __name__ == "__main__":
val = torch.rand(1, 16).cuda() val = torch.rand(1, 16).cuda()
# 初始化模型 # 初始化模型
model = MinamoModel().cuda() model = MinamoModel2().cuda()
print_memory("初始化后") print_memory("初始化后")
# 前向传播 # 前向传播
output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1, tag, val) output = model(input, 1, tag, val)
print_memory("前向传播后") print_memory("前向传播后")
print(f"输入形状: feat={input.shape}") print(f"输入形状: feat={input.shape}")
print(f"输出形状: output={output.shape}") print(f"输出形状: output={output.shape}")
# print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
# print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}") print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}") print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -2,12 +2,12 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.utils import spectral_norm from torch.nn.utils import spectral_norm
from torch_geometric.nn import GATConv from torch_geometric.nn import GATConv, TransformerConv
from torch_geometric.data import Data from torch_geometric.data import Data
class MinamoTopoModel(nn.Module): class MinamoTopoModel(nn.Module):
def __init__( def __init__(
self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512 self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512
): ):
super().__init__() super().__init__()
# 传入 softmax 概率值,直接映射 # 传入 softmax 概率值,直接映射
@ -16,9 +16,9 @@ class MinamoTopoModel(nn.Module):
nn.LeakyReLU(0.2) nn.LeakyReLU(0.2)
) )
# 图卷积层 # 图卷积层
self.conv1 = GATConv(emb_dim, hidden_dim, heads=8) self.conv1 = TransformerConv(emb_dim, hidden_dim, heads=8)
self.conv2 = GATConv(hidden_dim*8, hidden_dim, heads=8) self.conv2 = TransformerConv(hidden_dim*8, hidden_dim, heads=8)
self.conv3 = GATConv(hidden_dim*8, out_dim, heads=1) self.conv3 = TransformerConv(hidden_dim*8, out_dim, heads=1)
def forward(self, graph: Data): def forward(self, graph: Data):
x = self.input_proj(graph.x) x = self.input_proj(graph.x)

View File

@ -10,13 +10,10 @@ class MinamoVisionModel(nn.Module):
spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11 spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #9*9 spectral_norm(nn.Conv2d(in_ch*2, in_ch*8, 3)), #9*9
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7 spectral_norm(nn.Conv2d(in_ch*8, out_ch, 3)), # 7*7
nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*8, out_ch, 3)), # 5*5
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
) )

View File

@ -142,13 +142,14 @@ class GinkaWGANDataset(Dataset):
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE) removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
_, masked = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, 0.5)
rand = torch.rand(32, 32, 32, device=target.device) rand = torch.rand(32, 32, 32, device=target.device)
return { return {
"real1": removed1, "real1": removed1,
"masked1": rand, "masked1": rand,
"real2": removed2, "real2": removed2,
"masked2": torch.zeros_like(target), "masked2": masked,
"real3": removed3, "real3": removed3,
"masked3": torch.zeros_like(target), "masked3": torch.zeros_like(target),
"tag_cond": tag_cond, "tag_cond": tag_cond,

View File

@ -2,24 +2,25 @@ import torch
import torch.nn as nn import torch.nn as nn
from ..common.common import ConvFusionModule from ..common.common import ConvFusionModule
from ..common.cond import ConditionInjector from ..common.cond import ConditionInjector
from .unet import GinkaEncoderPath, GinkaDecoderPath
class RandomInputHead(nn.Module): class RandomInputHead(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.enc = ConvFusionModule(32, 256, 256, 32, 32) self.enc = GinkaEncoderPath(32, 32)
self.dec = GinkaDecoderPath(32)
self.out_conv = nn.Sequential( self.out_conv = nn.Sequential(
nn.Conv2d(256, 128, 3, padding=1, padding_mode='replicate'), nn.AdaptiveMaxPool2d((15, 15)),
nn.InstanceNorm2d(128), nn.Conv2d(32, 64, 3, padding=0),
nn.ELU(), nn.InstanceNorm2d(64),
nn.GELU(),
nn.AdaptiveMaxPool2d((13, 13)), nn.Conv2d(64, 32, 1),
nn.Conv2d(128, 32, 1),
) )
self.inject = ConditionInjector(256, 256)
def forward(self, x, cond): def forward(self, x, cond):
x = self.enc(x) x1, x2, x3, x4 = self.enc(x, cond)
x = self.inject(x, cond) x = self.dec(x1, x2, x3, x4, cond)
x = self.out_conv(x) x = self.out_conv(x)
return x return x
@ -28,15 +29,12 @@ class InputUpsample(nn.Module):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13), ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13),
nn.ELU(),
nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26 nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26
ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, 26, 26), ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, 26, 26),
nn.ELU(),
nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32 nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32
ConvFusionModule(hidden_ch, hidden_ch, out_ch, 32, 32), ConvFusionModule(hidden_ch, hidden_ch, out_ch, 32, 32),
nn.ELU(),
) )
def forward(self, x): # [B, C, 13, 13] def forward(self, x): # [B, C, 13, 13]
@ -47,18 +45,14 @@ class GinkaInput(nn.Module):
def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)): def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
super().__init__() super().__init__()
self.out_size = out_size self.out_size = out_size
self.enc1 = ConvFusionModule(in_ch, in_ch*4, in_ch, in_size[0], in_size[1])
self.upsample = InputUpsample(in_ch, in_ch*2, out_ch) self.upsample = InputUpsample(in_ch, in_ch*2, out_ch)
self.enc2 = ConvFusionModule(out_ch, out_ch*4, out_ch, out_size[0], out_size[1]) self.enc = ConvFusionModule(out_ch, out_ch*2, out_ch, out_size[0], out_size[1])
self.inject1 = ConditionInjector(256, in_ch) self.inject1 = ConditionInjector(256, out_ch)
self.inject2 = ConditionInjector(256, out_ch) self.inject2 = ConditionInjector(256, out_ch)
self.inject3 = ConditionInjector(256, out_ch)
def forward(self, x, cond): def forward(self, x, cond):
x = self.enc1(x)
x = self.inject1(x, cond)
x = self.upsample(x) x = self.upsample(x)
x = self.inject1(x, cond)
x = self.enc(x)
x = self.inject2(x, cond) x = self.inject2(x, cond)
x = self.enc2(x)
x = self.inject3(x, cond)
return x return x

View File

@ -1,12 +1,7 @@
import math
from tqdm import tqdm
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch_geometric.data import Data from torch_geometric.data import Data
from shared.graph import batch_convert_soft_map_to_graph
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from ..critic.model import MinamoModel
CLASS_NUM = 32 CLASS_NUM = 32
ILLEGAL_MAX_NUM = 30 ILLEGAL_MAX_NUM = 30
@ -156,15 +151,15 @@ def entrance_constraint_loss(
) )
return total_loss return total_loss
def input_head_illegal_loss(input_map, allowed_classes=(0, 1)): def input_head_illegal_loss(input_map, allowed_classes=[0, 1, 2]):
C = input_map.shape[1] C = input_map.shape[1]
mask = torch.ones(C, device=input_map.device) unallowed = get_not_allowed(allowed_classes, include_illegal=True)
mask[list(allowed_classes)] = 0 # 屏蔽允许的类别,其余为 1 illegal = input_map[:, unallowed, :, :]
illegal_class_penalty = (input_map * mask.view(1, -1, 1, 1)).sum() / input_map.numel() penalty = torch.sum(illegal)
return illegal_class_penalty
def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1): return penalty
def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=[1, 2]):
wall_prob = input_map[:, wall_class] # [B, H, W] wall_prob = input_map[:, wall_class] # [B, H, W]
wall_ratio = wall_prob.mean() # 计算平均墙体占比 wall_ratio = wall_prob.mean() # 计算平均墙体占比
wall_penalty = torch.clamp(wall_ratio - max_wall_ratio, min=0.0) # 超过则惩罚 wall_penalty = torch.clamp(wall_ratio - max_wall_ratio, min=0.0) # 超过则惩罚
@ -241,6 +236,16 @@ def immutable_penalty_loss(
return penalty return penalty
def modifiable_penalty_loss(
    probs: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
) -> torch.Tensor:
    """Penalise predictions that drop modifiable-class content present in the input.

    For the channels listed in ``modifiable_classes``, takes ``input - probs``
    clamped to [0, 1] (positive only where the input holds more mass than the
    prediction) and returns its mean squared value.
    """
    shortfall = (
        input[:, modifiable_classes, :, :] - probs[:, modifiable_classes, :, :]
    ).clamp(min=0.0, max=1.0)
    # MSE against an all-zero target, i.e. the mean of the squared shortfall.
    return F.mse_loss(shortfall, torch.zeros_like(shortfall))
def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]): def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
not_allowed = get_not_allowed(legal_classes, include_illegal=True) not_allowed = get_not_allowed(legal_classes, include_illegal=True)
input_mask = pred[:, not_allowed, :, :] input_mask = pred[:, not_allowed, :, :]
@ -249,43 +254,40 @@ def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
return penalty return penalty
class WGANGinkaLoss: class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.05, 0.5]): def __init__(self, lambda_gp=100, weight=[1, 0.4, 50, 0.2, 0.2, 0.05, 0.4]):
# weight: 判别器损失CE 损失,不可修改类型损失和非法图块损失,图块类型损失,入口存在性损失,多样性损失,密度损失 # weight:
# 1. 判别器损失及图块维持损失(可修改部分的已有内容不可修改)
# 2. CE 损失
# 3. 不可修改类型损失和非法图块损失
# 4. 图块类型损失
# 5. 入口存在性损失
# 6. 多样性损失
# 7. 密度损失
self.lambda_gp = lambda_gp # 梯度惩罚系数 self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight self.weight = weight
def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond): def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond):
# 进行插值 # 进行插值
batch_size = real_data.size(0) batch_size = real_data.size(0)
epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device) epsilon_data = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device) interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device)
interp_graph = batch_convert_soft_map_to_graph(interp_data).to(real_data.device)
# 对图像进行反向传播并计算梯度 # 对图像进行反向传播并计算梯度
interp_data.requires_grad_() interp_data.requires_grad_()
interp_graph.x.requires_grad_()
_, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage, tag_cond, val_cond) d_score = critic(interp_data, stage, tag_cond, val_cond)
# 计算梯度 # 计算梯度
grad_vis = torch.autograd.grad( grad = torch.autograd.grad(
outputs=d_vis_score, inputs=interp_data, outputs=d_score, inputs=interp_data,
grad_outputs=torch.ones_like(d_vis_score), grad_outputs=torch.ones_like(d_score),
create_graph=True, retain_graph=True, only_inputs=True
)[0]
grad_topo = torch.autograd.grad(
outputs=d_topo_score, inputs=interp_graph.x,
grad_outputs=torch.ones_like(d_topo_score),
create_graph=True, retain_graph=True, only_inputs=True create_graph=True, retain_graph=True, only_inputs=True
)[0] )[0]
# 计算梯度的 L2 范数 # 计算梯度的 L2 范数
grad_norm_vis = grad_vis.view(batch_size, -1).norm(2, dim=1) grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1)
grad_norm_topo = grad_topo.view(batch_size, -1).norm(2, dim=1)
# 计算梯度惩罚项 # 计算梯度惩罚项
gp_loss_vis = ((grad_norm_vis - 1.0) ** 2).mean() gp_loss = ((grad_norm - 1.0) ** 2).mean()
gp_loss_topo = ((grad_norm_topo - 1.0) ** 2).mean()
gp_loss = gp_loss_vis * VISION_WEIGHT + gp_loss_topo * TOPO_WEIGHT
# print(grad_norm_topo.mean().item(), grad_norm_vis.mean().item()) # print(grad_norm_topo.mean().item(), grad_norm_vis.mean().item())
return gp_loss return gp_loss
@ -296,10 +298,8 @@ class WGANGinkaLoss:
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" 判别器损失函数 """ """ 判别器损失函数 """
fake_data = F.softmax(fake_data, dim=1) fake_data = F.softmax(fake_data, dim=1)
real_graph = batch_convert_soft_map_to_graph(real_data) real_scores = critic(real_data, stage, tag_cond, val_cond)
fake_graph = batch_convert_soft_map_to_graph(fake_data) fake_scores = critic(fake_data, stage, tag_cond, val_cond)
real_scores, _, _ = critic(real_data, real_graph, stage, tag_cond, val_cond)
fake_scores, _, _ = critic(fake_data, fake_graph, stage, tag_cond, val_cond)
# Wasserstein 距离 # Wasserstein 距离
d_loss = fake_scores.mean() - real_scores.mean() d_loss = fake_scores.mean() - real_scores.mean()
@ -312,10 +312,9 @@ class WGANGinkaLoss:
def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" 生成器损失函数 """ """ 生成器损失函数 """
probs_fake = F.softmax(fake, dim=1) probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小 ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage]) immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
constraint_loss = inner_constraint_loss(probs_fake) constraint_loss = inner_constraint_loss(probs_fake)
@ -343,9 +342,8 @@ class WGANGinkaLoss:
def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor: def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor:
probs_fake = F.softmax(fake, dim=1) probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) minamo_loss = -torch.mean(fake_scores)
illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage]) illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
constraint_loss = inner_constraint_loss(probs_fake) constraint_loss = inner_constraint_loss(probs_fake)
@ -370,10 +368,9 @@ class WGANGinkaLoss:
def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor: def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor:
probs_fake = F.softmax(fake, dim=1) probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage]) immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
constraint_loss = inner_constraint_loss(probs_fake) constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
@ -395,13 +392,15 @@ class WGANGinkaLoss:
return sum(losses) return sum(losses)
def generator_input_head_loss(self, probs: torch.Tensor) -> torch.Tensor: def generator_input_head_loss(self, critic, map: torch.Tensor, tag_cond, val_cond) -> torch.Tensor:
probs = F.softmax(map, dim=1)
head_scores = critic(probs, 0, tag_cond, val_cond)
probs_a, probs_b = probs.chunk(2, dim=0) probs_a, probs_b = probs.chunk(2, dim=0)
losses = [ losses = [
torch.mean(head_scores),
input_head_illegal_loss(probs), input_head_illegal_loss(probs),
input_head_wall_loss(probs), -js_divergence(probs_a, probs_b, softmax=False) * 0.1
-js_divergence(probs_a, probs_b, softmax=False) * 0.3
] ]
return sum(losses) return sum(losses)

View File

@ -1,22 +1,15 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..common.common import GCNBlock, DoubleConvBlock from ..common.common import ConvFusionModule
from ..common.cond import ConditionInjector from ..common.cond import ConditionInjector
class StageHead(nn.Module): class StageHead(nn.Module):
def __init__(self, in_ch, out_ch, out_size=(13, 13)): def __init__(self, in_ch, out_ch, out_size=(13, 13)):
super().__init__() super().__init__()
self.cnn_head = DoubleConvBlock([in_ch, in_ch*2, in_ch]) self.dec = ConvFusionModule(in_ch, in_ch*2, in_ch, 32, 32)
self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
self.pool = nn.Sequential( self.pool = nn.Sequential(
nn.Conv2d(in_ch, in_ch*2, 3, padding=1, padding_mode='replicate'), ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32),
nn.InstanceNorm2d(in_ch*2), ConvFusionModule(in_ch*2, in_ch*2, in_ch, 32, 32),
nn.ELU(),
nn.Conv2d(in_ch*2, in_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(in_ch),
nn.ELU(),
nn.AdaptiveMaxPool2d(out_size), nn.AdaptiveMaxPool2d(out_size),
nn.Conv2d(in_ch, out_ch, 1) nn.Conv2d(in_ch, out_ch, 1)
@ -24,10 +17,7 @@ class StageHead(nn.Module):
self.inject = ConditionInjector(256, in_ch) self.inject = ConditionInjector(256, in_ch)
def forward(self, x, cond): def forward(self, x, cond):
x_cnn = self.cnn_head(x) x = self.dec(x)
x_gcn = self.gcn_head(x)
x = torch.cat([x_cnn, x_gcn], dim=1)
x = self.fusion(x)
x = self.inject(x, cond) x = self.inject(x, cond)
x = self.pool(x) x = self.pool(x)
return x return x

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from shared.attention import ChannelAttention from shared.attention import ChannelAttention
from ..common.common import GCNBlock, TransformerGCNBlock from ..common.common import GCNBlock, TransformerGCNBlock, DoubleConvBlock, ConvFusionModule
from ..common.cond import ConditionInjector from ..common.cond import ConditionInjector
class GinkaTransformerEncoder(nn.Module): class GinkaTransformerEncoder(nn.Module):
@ -37,16 +37,17 @@ class GinkaTransformerEncoder(nn.Module):
class ConvBlock(nn.Module): class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, attn=True): def __init__(self, in_ch, out_ch, attn=True):
super().__init__() super().__init__()
self.conv = nn.Sequential( self.conv = DoubleConvBlock([in_ch, out_ch, out_ch])
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'), # self.conv = nn.Sequential(
nn.InstanceNorm2d(out_ch), # nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.ELU(), # nn.InstanceNorm2d(out_ch),
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'), # nn.ELU(),
nn.InstanceNorm2d(out_ch), # nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
) # nn.InstanceNorm2d(out_ch),
if attn: # )
self.conv.append(ChannelAttention(out_ch)) # if attn:
self.conv.append(nn.ELU()) # self.conv.append(ChannelAttention(out_ch))
# self.conv.append(nn.ELU())
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
@ -64,47 +65,24 @@ class FusionModule(nn.Module):
class GinkaUNetInput(nn.Module): class GinkaUNetInput(nn.Module):
def __init__(self, in_ch, out_ch, w, h): def __init__(self, in_ch, out_ch, w, h):
super().__init__() super().__init__()
self.conv = ConvBlock(in_ch, in_ch) self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
self.gcn = TransformerGCNBlock(in_ch, in_ch*2, in_ch, w, h)
self.fusion = ConvBlock(in_ch*2, out_ch)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, cond):
x1 = self.conv(x)
x2 = self.gcn(x)
x = torch.cat([x1, x2], dim=1)
x = self.fusion(x)
x = self.inject(x, cond)
return x
class GinkaEncoder(nn.Module):
"""编码器(下采样)部分"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = ConvBlock(in_ch, out_ch)
self.pool = nn.MaxPool2d(2)
self.inject = ConditionInjector(256, out_ch) self.inject = ConditionInjector(256, out_ch)
def forward(self, x, cond): def forward(self, x, cond):
x = self.conv(x) x = self.conv(x)
x = self.pool(x)
x = self.inject(x, cond) x = self.inject(x, cond)
return x return x
class GinkaGCNFusedEncoder(nn.Module): class GinkaEncoder(nn.Module):
def __init__(self, in_ch, out_ch, w, h): def __init__(self, in_ch, out_ch, w, h):
super().__init__() super().__init__()
self.conv = ConvBlock(in_ch, out_ch)
self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
self.pool = nn.MaxPool2d(2) self.pool = nn.MaxPool2d(2)
self.fusion = FusionModule(out_ch*2, out_ch) self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
self.inject = ConditionInjector(256, out_ch) self.inject = ConditionInjector(256, out_ch)
def forward(self, x, cond): def forward(self, x, cond):
x = self.conv(x)
x = self.pool(x) x = self.pool(x)
x2 = self.gcn(x) x = self.conv(x)
x = self.fusion(x, x2)
x = self.inject(x, cond) x = self.inject(x, cond)
return x return x
@ -114,42 +92,29 @@ class GinkaUpSample(nn.Module):
self.conv = nn.Sequential( self.conv = nn.Sequential(
nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2), nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
nn.InstanceNorm2d(out_ch), nn.InstanceNorm2d(out_ch),
nn.ELU(), nn.GELU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
nn.GELU()
) )
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
class GinkaDecoder(nn.Module): class GinkaDecoder(nn.Module):
"""解码器(上采样)部分"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
self.conv = ConvBlock(in_ch, out_ch)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, feat, cond):
x = self.upsample(x)
x = torch.cat([x, feat], dim=1)
x = self.conv(x)
x = self.inject(x, cond)
return x
class GinkaGCNFusedDecoder(nn.Module):
def __init__(self, in_ch, out_ch, w, h): def __init__(self, in_ch, out_ch, w, h):
super().__init__() super().__init__()
self.upsample = GinkaUpSample(in_ch, in_ch // 2) self.upsample = GinkaUpSample(in_ch, in_ch // 2)
self.conv = ConvBlock(in_ch, out_ch) self.fusion = nn.Conv2d(in_ch, in_ch, 1)
self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h) self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
self.fusion = FusionModule(out_ch*2, out_ch)
self.inject = ConditionInjector(256, out_ch) self.inject = ConditionInjector(256, out_ch)
def forward(self, x, feat, cond): def forward(self, x, feat, cond):
x = self.upsample(x) x = self.upsample(x)
x = torch.cat([x, feat], dim=1) x = torch.cat([x, feat], dim=1)
x = self.fusion(x)
x = self.conv(x) x = self.conv(x)
x2 = self.gcn(x)
x = self.fusion(x, x2)
x = self.inject(x, cond) x = self.inject(x, cond)
return x return x
@ -162,58 +127,62 @@ class GinkaBottleneck(nn.Module):
# ) # )
# self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4) # self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
# self.fusion = nn.Conv2d(module_ch*3, module_ch, 1) # self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
self.conv = ConvBlock(module_ch, module_ch) self.conv = ConvFusionModule(module_ch, module_ch, module_ch, w, h)
self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, w, h)
self.fusion = nn.Conv2d(module_ch*2, module_ch, 1)
self.inject = ConditionInjector(256, module_ch) self.inject = ConditionInjector(256, module_ch)
def forward(self, x, cond): def forward(self, x, cond):
B = x.size(0)
# x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch] # x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
# x1 = self.transformer(x1) # x1 = self.transformer(x1)
# x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4] # x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
x1 = self.conv(x) x = self.conv(x)
x2 = self.gcn(x)
x = torch.cat([x1, x2], dim=1)
x = self.fusion(x)
x = self.inject(x, cond) x = self.inject(x, cond)
return x return x
class GinkaUNet(nn.Module): class GinkaEncoderPath(nn.Module):
def __init__(self, in_ch=32, base_ch=64, out_ch=32): def __init__(self, in_ch, base_ch):
"""Ginka Model UNet 部分
"""
super().__init__() super().__init__()
self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32) self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32)
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16) self.down2 = GinkaEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8) self.down3 = GinkaEncoder(base_ch*2, base_ch*4, 8, 8)
self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4) self.down4 = GinkaEncoder(base_ch*4, base_ch*8, 4, 4)
self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8)
self.up2 = GinkaGCNFusedDecoder(base_ch*4, base_ch*2, 16, 16)
self.up3 = GinkaGCNFusedDecoder(base_ch*2, base_ch, 32, 32)
self.final = nn.Sequential(
nn.Conv2d(base_ch, out_ch, 1),
nn.InstanceNorm2d(out_ch),
nn.ELU(),
)
def forward(self, x, cond): def forward(self, x, cond):
x1 = self.down1(x, cond) # [B, 64, 32, 32] x1 = self.down1(x, cond) # [B, 64, 32, 32]
x2 = self.down2(x1, cond) # [B, 128, 16, 16] x2 = self.down2(x1, cond) # [B, 128, 16, 16]
x3 = self.down3(x2, cond) # [B, 256, 8, 8] x3 = self.down3(x2, cond) # [B, 256, 8, 8]
x4 = self.down4(x3, cond) # [B, 512, 4, 4] x4 = self.down4(x3, cond) # [B, 512, 4, 4]
x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4]
# 上采样 return x1, x2, x3, x4
class GinkaDecoderPath(nn.Module):
def __init__(self, base_ch):
super().__init__()
self.up1 = GinkaDecoder(base_ch*8, base_ch*4, 8, 8)
self.up2 = GinkaDecoder(base_ch*4, base_ch*2, 16, 16)
self.up3 = GinkaDecoder(base_ch*2, base_ch, 32, 32)
def forward(self, x1, x2, x3, x4, cond):
x = self.up1(x4, x3, cond) # [B, 256, 8, 8] x = self.up1(x4, x3, cond) # [B, 256, 8, 8]
x = self.up2(x, x2, cond) # [B, 128, 16, 16] x = self.up2(x, x2, cond) # [B, 128, 16, 16]
x = self.up3(x, x1, cond) # [B, 64, 32, 32] x = self.up3(x, x1, cond) # [B, 64, 32, 32]
return x
class GinkaUNet(nn.Module):
def __init__(self, in_ch=32, base_ch=32, out_ch=32):
"""Ginka Model UNet 部分
"""
super().__init__()
self.enc = GinkaEncoderPath(in_ch, base_ch)
self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
self.dec = GinkaDecoderPath(base_ch)
self.final = ConvFusionModule(base_ch, base_ch, out_ch, 32, 32)
def forward(self, x, cond):
x1, x2, x3, x4 = self.enc(x, cond)
x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4]
x = self.dec(x1, x2, x3, x4, cond)
x = self.final(x) # [B, 32, 32, 32] x = self.final(x) # [B, 32, 32, 32]
return x return x

View File

@ -6,12 +6,13 @@ import torch
import torch.optim as optim import torch.optim as optim
import torch.nn.functional as F import torch.nn.functional as F
import cv2 import cv2
import numpy as np
from torch_geometric.loader import DataLoader from torch_geometric.loader import DataLoader
from tqdm import tqdm from tqdm import tqdm
from .generator.model import GinkaModel from .generator.model import GinkaModel
from .dataset import GinkaWGANDataset from .dataset import GinkaWGANDataset
from .generator.loss import WGANGinkaLoss from .generator.loss import WGANGinkaLoss
from .critic.model import MinamoModel from .critic.model import MinamoModel2
from shared.image import matrix_to_image_cv from shared.image import matrix_to_image_cv
# 标签定义: # 标签定义:
@ -105,7 +106,7 @@ def train():
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程 stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
ginka = GinkaModel().to(device) ginka = GinkaModel().to(device)
minamo = MinamoModel().to(device) minamo = MinamoModel2().to(device)
dataset = GinkaWGANDataset(args.train, device) dataset = GinkaWGANDataset(args.train, device)
dataset_val = GinkaWGANDataset(args.validate, device) dataset_val = GinkaWGANDataset(args.validate, device)
@ -113,7 +114,7 @@ def train():
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9)) optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9))
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs) # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs) # scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
@ -201,14 +202,24 @@ def train():
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
elif train_stage == 3 or train_stage == 4: elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
if train_stage == 4:
loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, masked2, x_in, tag_cond, val_cond)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond) loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond) loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond)
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond) loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond)
dis_avg = (dis1 + dis2 + dis3) / 3.0 dis = [dis1, dis2, dis3]
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0 loss_d = [loss_d1, loss_d2, loss_d3]
if train_stage == 4:
dis.append(dis0)
loss_d.append(loss_d0)
dis_avg = sum(dis) / len(dis)
loss_d_avg = sum(loss_d) / len(loss_d)
# 反向传播 # 反向传播
loss_d_avg.backward() loss_d_avg.backward()
@ -230,7 +241,7 @@ def train():
loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond) loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond) loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0 loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0
loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3) loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
loss_g.backward() loss_g.backward()
@ -240,19 +251,16 @@ def train():
elif train_stage == 3 or train_stage == 4: elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4) fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
if train_stage == 3: loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, x_in, tag_cond, val_cond)
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1, tag_cond, val_cond)
else:
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1, tag_cond, val_cond)
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond) loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond)
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond) loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)
if train_stage == 4: if train_stage == 4:
loss_head = criterion.generator_input_head_loss(x_in) loss_head = criterion.generator_input_head_loss(minamo, x_in, tag_cond, val_cond)
loss_head.backward(retain_graph=True) loss_head.backward(retain_graph=True)
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0 loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0
loss_g.backward() loss_g.backward()
optimizer_ginka.step() optimizer_ginka.step()
loss_total_ginka += loss_g.detach() loss_total_ginka += loss_g.detach()
@ -286,6 +294,8 @@ def train():
}, f"result/wgan/minamo-{epoch + 1}.pth") }, f"result/wgan/minamo-{epoch + 1}.pth")
idx = 0 idx = 0
gap = 5
color = (255, 255, 255) # 白色
with torch.no_grad(): with torch.no_grad():
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
real1 = batch["real1"].to(device) real1 = batch["real1"].to(device)
@ -301,17 +311,42 @@ def train():
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
elif train_stage == 3 or train_stage == 4: elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
x_in = torch.argmax(x_in, dim=1).cpu().numpy()
fake1 = torch.argmax(fake1, dim=1).cpu().numpy() fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy() fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
fake3 = torch.argmax(fake3, dim=1).cpu().numpy() fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
masked1 = torch.argmax(masked1, dim=1).cpu().numpy()
masked2 = torch.argmax(masked2, dim=1).cpu().numpy()
masked3 = torch.argmax(masked3, dim=1).cpu().numpy()
for i in range(fake1.shape[0]): for i in range(fake1.shape[0]):
for key, one in enumerate([fake1, fake2, fake3]): fake1_img = matrix_to_image_cv(fake1[i], tile_dict)
map_matrix = one[i] fake2_img = matrix_to_image_cv(fake2[i], tile_dict)
image = matrix_to_image_cv(map_matrix, tile_dict) fake3_img = matrix_to_image_cv(fake3[i], tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image) if train_stage == 1 or train_stage == 2:
vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线
hline = np.full((gap, 3 * 416 + gap * 2, 3), color, dtype=np.uint8) # 水平分割线
in1_img = matrix_to_image_cv(masked1[i], tile_dict)
in2_img = matrix_to_image_cv(masked2[i], tile_dict)
in3_img = matrix_to_image_cv(masked3[i], tile_dict)
img = np.block([
[[in1_img], [vline], [in2_img], [vline], [in3_img]],
[[hline]],
[[fake1_img], [vline], [fake2_img], [vline], [fake3_img]]
])
elif train_stage == 3 or train_stage == 4:
vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线
hline = np.full((gap, 2 * 416 + gap, 3), color, dtype=np.uint8) # 水平分割线
in_img = matrix_to_image_cv(x_in[i], tile_dict)
img = np.block([
[[in_img], [vline], [fake1_img]],
[[hline]],
[[fake2_img], [vline], [fake3_img]]
])
cv2.imwrite(f"result/ginka_img/{idx}.png", img)
idx += 1 idx += 1