Mirror of https://github.com/unanmed/ginka-generator.git, synced 2026-05-14 12:57:15 +08:00

refactor: reorganize directory structure & feat: condition injection

This commit is contained in: parent a28e56456d, commit 44b90e7630

ginka/common/cond.py (new file, +43 lines)
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class ConditionEncoder(nn.Module):
+    def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
+        super().__init__()
+        self.tag_embed = nn.Linear(tag_dim, hidden_dim)
+        self.val_embed = nn.Linear(val_dim, hidden_dim)
+        self.fusion = nn.Sequential(
+            nn.LayerNorm(hidden_dim*2),
+            nn.ELU(),
+
+            nn.Linear(hidden_dim*2, hidden_dim*4),
+            nn.LayerNorm(hidden_dim*4),
+            nn.ELU(),
+
+            nn.Linear(hidden_dim*4, out_dim)
+        )
+
+    def forward(self, tag, val):
+        tag = self.tag_embed(tag)
+        val = self.val_embed(val)
+        feat = torch.cat([tag, val], dim=1)
+        feat = self.fusion(feat)
+        return feat
+
+class ConditionInjector(nn.Module):
+    def __init__(self, cond_dim, out_dim):
+        super().__init__()
+        self.fc = nn.Sequential(
+            nn.Linear(cond_dim, cond_dim*2),
+            nn.LayerNorm(cond_dim*2),
+            nn.ELU(),
+
+            nn.Linear(cond_dim*2, out_dim)
+        )
+
+    def forward(self, x, cond):
+        cond = self.fc(cond)
+        B, D = cond.shape
+        cond = cond.view(B, D, 1, 1)
+        return x + cond
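For orientation, a minimal, hypothetical smoke test for the two new modules (not part of the commit; shapes follow the ConditionEncoder(64, 16, 128, 256) and ConditionInjector(256, ...) instantiations used elsewhere in this diff, and the import path follows the new file location):

    import torch
    from ginka.common.cond import ConditionEncoder, ConditionInjector

    # Hypothetical usage sketch: encode a (tag, value) condition pair, then
    # inject it into a feature map as a spatially broadcast additive bias --
    # a FiLM-style shift without the scale term.
    enc = ConditionEncoder(tag_dim=64, val_dim=16, hidden_dim=128, out_dim=256)
    inj = ConditionInjector(cond_dim=256, out_dim=128)

    tag = torch.rand(4, 64)             # tag condition vector
    val = torch.rand(4, 16)             # scalar condition vector
    feat = torch.randn(4, 128, 16, 16)  # some intermediate feature map

    cond = enc(tag, val)                # [4, 256]
    out = inj(feat, cond)               # cond -> [4, 128, 1, 1], broadcast-added
    assert out.shape == feat.shape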
@@ -3,30 +3,17 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.utils import spectral_norm
 from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool
-from .vision import MinamoVisionModel
-from .topo import MinamoTopoModel
 from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
 from shared.graph import batch_convert_soft_map_to_graph
+from .vision import MinamoVisionModel
+from .topo import MinamoTopoModel
+from ..common.cond import ConditionEncoder, ConditionInjector
 
 def print_memory(tag=""):
     print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
 
-class MinamoModel(nn.Module):
-    def __init__(self, tile_types=32):
-        super().__init__()
-        # visual similarity part
-        self.vision_model = MinamoVisionModel(tile_types)
-        # topological similarity part
-        self.topo_model = MinamoTopoModel(tile_types)
-
-    def forward(self, map, graph):
-        vision_feat = self.vision_model(map)
-        topo_feat = self.topo_model(graph)
-
-        return vision_feat, topo_feat
-
 class CNNHead(nn.Module):
-    def __init__(self, in_ch, out_dim):
+    def __init__(self, in_ch):
         super().__init__()
         self.cnn = nn.Sequential(
             spectral_norm(nn.Conv2d(in_ch, in_ch, 3)),
@@ -35,61 +22,69 @@ class CNNHead(nn.Module):
             nn.AdaptiveMaxPool2d((2, 2))
         )
         self.fc = nn.Sequential(
-            spectral_norm(nn.Linear(in_ch*2*2, out_dim))
+            spectral_norm(nn.Linear(in_ch*2*2, 1))
         )
+        self.proj = nn.Linear(256, in_ch*2*2)
 
-    def forward(self, x):
+    def forward(self, x, cond):
         x = self.cnn(x)
         B, C, H, W = x.shape
         x = x.view(B, -1)
-        x = self.fc(x)
+        cond = self.proj(cond)
+        proj = torch.sum(x * cond, dim=1, keepdim=True)
+        x = self.fc(x) + proj
         return x
 
 class GCNHead(nn.Module):
-    def __init__(self, in_dim, out_dim):
+    def __init__(self, in_dim):
         super().__init__()
         self.gcn = GCNConv(in_dim, in_dim)
+        self.proj = nn.Linear(256, in_dim)
         self.fc = nn.Sequential(
-            spectral_norm(nn.Linear(in_dim, out_dim))
+            spectral_norm(nn.Linear(in_dim, 1))
         )
 
-    def forward(self, x, graph):
+    def forward(self, x, graph, cond):
         x = self.gcn(x, graph.edge_index)
         x = F.leaky_relu(x, 0.2)
         x = global_max_pool(x, graph.batch)
-        x = self.fc(x)
+        cond = self.proj(cond)
+        proj = torch.sum(x * cond, dim=1, keepdim=True)
+        x = self.fc(x) + proj
         return x
 
 class MinamoScoreHead(nn.Module):
-    def __init__(self, vision_dim, topo_dim, out_dim):
+    def __init__(self, vision_dim, topo_dim):
         super().__init__()
-        self.vision_head = CNNHead(vision_dim, out_dim)
-        self.topo_head = GCNHead(topo_dim, out_dim)
+        self.vision_head = CNNHead(vision_dim)
+        self.topo_head = GCNHead(topo_dim)
 
-    def forward(self, vis, topo, graph):
-        vis_score = self.vision_head(vis)
-        topo_score = self.topo_head(topo, graph)
+    def forward(self, vis, topo, graph, cond):
+        vis_score = self.vision_head(vis, cond)
+        topo_score = self.topo_head(topo, graph, cond)
         return vis_score, topo_score
 
-class MinamoScoreModule(nn.Module):
+class MinamoModel(nn.Module):
     def __init__(self, tile_types=32):
         super().__init__()
         self.topo_model = MinamoTopoModel(tile_types)
         self.vision_model = MinamoVisionModel(tile_types)
+        self.cond = ConditionEncoder(64, 16, 128, 256)
         # output heads
-        self.head1 = MinamoScoreHead(512, 512, 1)
-        self.head2 = MinamoScoreHead(512, 512, 1)
-        self.head3 = MinamoScoreHead(512, 512, 1)
+        self.head1 = MinamoScoreHead(512, 512)
+        self.head2 = MinamoScoreHead(512, 512)
+        self.head3 = MinamoScoreHead(512, 512)
 
-    def forward(self, map, graph, stage):
+    def forward(self, map, graph, stage, tag_cond, val_cond):
         vision = self.vision_model(map)
         topo = self.topo_model(graph)
+        cond = self.cond(tag_cond, val_cond)
         if stage == 1:
-            vision_score, topo_score = self.head1(vision, topo, graph)
+            vision_score, topo_score = self.head1(vision, topo, graph, cond)
         elif stage == 2:
-            vision_score, topo_score = self.head2(vision, topo, graph)
+            vision_score, topo_score = self.head2(vision, topo, graph, cond)
         elif stage == 3:
-            vision_score, topo_score = self.head3(vision, topo, graph)
+            vision_score, topo_score = self.head3(vision, topo, graph, cond)
         else:
             raise RuntimeError("Unknown critic stage.")
         score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
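The conditioned heads above follow the projection-discriminator pattern: the critic score is an unconditional linear term plus an inner product between the pooled features and a learned projection of the condition vector (our reading of the code; cf. Miyato & Koyama's projection cGAN, which the commit itself does not cite). A minimal standalone sketch:

    import torch
    import torch.nn as nn

    # Sketch of projection-style conditioning, mirroring CNNHead/GCNHead above.
    B, D, C = 4, 128, 256          # batch, pooled feature dim, cond dim (assumed)
    fc = nn.Linear(D, 1)           # unconditional score term
    proj = nn.Linear(C, D)         # projects the condition into feature space

    phi = torch.randn(B, D)        # pooled critic features
    cond = torch.randn(B, C)       # output of ConditionEncoder
    score = fc(phi) + torch.sum(phi * proj(cond), dim=1, keepdim=True)  # [B, 1]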
@@ -98,19 +93,22 @@ class MinamoScoreModule(nn.Module):
 # check VRAM usage
 if __name__ == "__main__":
     input = torch.randn((1, 32, 13, 13)).cuda()
+    tag = torch.rand(1, 64).cuda()
+    val = torch.rand(1, 16).cuda()
 
     # initialize the model
-    model = MinamoScoreModule().cuda()
+    model = MinamoModel().cuda()
 
     print_memory("after init")
 
     # forward pass
-    output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1)
+    output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1, tag, val)
 
     print_memory("after forward")
 
     print(f"Input shape: feat={input.shape}")
     print(f"Output shape: output={output.shape}")
+    print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
     print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
     print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
     print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}")
@@ -87,54 +87,94 @@ class GinkaWGANDataset(Dataset):
     def __len__(self):
         return len(self.data)
 
-    def handle_stage1(self, target):
+    def handle_stage1(self, target, tag_cond, val_cond):
        # curriculum stage 1: fill in masked regions
        removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
        removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
        removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
 
-        return removed1, masked1, removed2, masked2, removed3, masked3
+        return {
+            "real1": removed1,
+            "masked1": masked1,
+            "real2": removed2,
+            "masked2": masked2,
+            "real3": removed3,
+            "masked3": masked3,
+            "tag_cond": tag_cond,
+            "val_cond": val_cond
+        }
 
-    def handle_stage2(self, target):
+    def handle_stage2(self, target, tag_cond, val_cond):
        # curriculum stage 2: fully random masks
        removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
        # the last two stages keep some tile classes, so fully random masking suffices
        removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
        removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))
 
-        return removed1, masked1, removed2, masked2, removed3, masked3
+        return {
+            "real1": removed1,
+            "masked1": masked1,
+            "real2": removed2,
+            "masked2": masked2,
+            "real3": removed3,
+            "masked3": masked3,
+            "tag_cond": tag_cond,
+            "val_cond": val_cond
+        }
 
-    def handle_stage3(self, target):
+    def handle_stage3(self, target, tag_cond, val_cond):
        # stage 3: joint generation with randomly masked input
        removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
        removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
        removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
-        return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
+
+        return {
+            "real1": removed1,
+            "masked1": masked1,
+            "real2": removed2,
+            "masked2": torch.zeros_like(target),
+            "real3": removed3,
+            "masked3": torch.zeros_like(target),
+            "tag_cond": tag_cond,
+            "val_cond": val_cond
+        }
 
-    def handle_stage4(self, target):
-        # stage 4: alternates with stage 2, fully random input
+    def handle_stage4(self, target, tag_cond, val_cond):
+        # stage 4: fully random input
        removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
        removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
        removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
        rand = torch.rand(32, 32, 32, device=target.device)
-        return removed1, rand, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
+
+        return {
+            "real1": removed1,
+            "masked1": rand,
+            "real2": removed2,
+            "masked2": torch.zeros_like(target),
+            "real3": removed3,
+            "masked3": torch.zeros_like(target),
+            "tag_cond": tag_cond,
+            "val_cond": val_cond
+        }
 
     def __getitem__(self, idx):
         item = self.data[idx]
 
         target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
+        tag_cond = torch.FloatTensor(item['tag'])
+        val_cond = torch.FloatTensor(item['val'])
 
         if self.train_stage == 1:
-            return self.handle_stage1(target)
+            return self.handle_stage1(target, tag_cond, val_cond)
 
         elif self.train_stage == 2:
-            return self.handle_stage2(target)
+            return self.handle_stage2(target, tag_cond, val_cond)
 
         elif self.train_stage == 3:
-            return self.handle_stage3(target)
+            return self.handle_stage3(target, tag_cond, val_cond)
 
         elif self.train_stage == 4:
-            return self.handle_stage4(target)
+            return self.handle_stage4(target, tag_cond, val_cond)
 
         raise RuntimeError(f"Invalid train stage: {self.train_stage}")
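A side effect of returning dicts from __getitem__ is that PyTorch's default collation batches each key separately, which is why the training loop later in this diff switches from tuple unpacking to keyed access like batch["real1"]. A minimal illustration with hypothetical data (not from the commit; the torch_geometric DataLoader used in train.py handles mappings the same way):

    import torch
    from torch.utils.data import default_collate

    # Two dict samples, shaped like what handle_stage1 returns.
    samples = [
        {"real1": torch.zeros(32, 13, 13), "tag_cond": torch.zeros(64)},
        {"real1": torch.ones(32, 13, 13), "tag_cond": torch.ones(64)},
    ]
    batch = default_collate(samples)
    print(batch["real1"].shape)     # torch.Size([2, 32, 13, 13])
    print(batch["tag_cond"].shape)  # torch.Size([2, 64])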
@@ -1,29 +1,34 @@
 import torch
 import torch.nn as nn
 from ..common.common import GCNBlock, DoubleConvBlock
+from ..common.cond import ConditionInjector
 
 class RandomInputHead(nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(32, 32, 3, padding=1, padding_mode='replicate'),
-            nn.InstanceNorm2d(32),
-            nn.ELU(),
-
-            nn.Conv2d(32, 64, 3, padding=1, padding_mode='replicate'),
-            nn.InstanceNorm2d(64),
-            nn.ELU(),
-
-            nn.Conv2d(64, 128, 3, padding=1, padding_mode='replicate'),
-            nn.InstanceNorm2d(128),
+        self.conv = DoubleConvBlock([32, 64, 128])
+        self.gcn = GCNBlock(32, 128, 128, 32, 32)
+        self.fusion = nn.Sequential(
+            nn.Conv2d(256, 256, 3, padding=1, padding_mode='replicate'),
+            nn.InstanceNorm2d(256),
             nn.ELU(),
         )
+        self.out_conv = nn.Sequential(
+            nn.Conv2d(256, 128, 3, padding=1, padding_mode='replicate'),
+            nn.InstanceNorm2d(128),
+            nn.ELU(),
+
+            nn.AdaptiveMaxPool2d((13, 13)),
+            nn.Conv2d(128, 32, 1),
+        )
+        self.inject = ConditionInjector(256, 256)
 
-    def forward(self, x):
-        x = self.conv(x)
+    def forward(self, x, cond):
+        x_cnn = self.conv(x)
+        x_gcn = self.gcn(x)
+        x = torch.cat([x_cnn, x_gcn], dim=1)
+        x = self.fusion(x)
+        x = self.inject(x, cond)
+        x = self.out_conv(x)
         return x
@@ -4,11 +4,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch_geometric.data import Data
-from minamo.model.model import MinamoModel
 from shared.graph import batch_convert_soft_map_to_graph
 from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
 from shared.similarity.topo import overall_similarity, build_topological_graph
 from shared.similarity.vision import calculate_visual_similarity
+from ..critic.model import MinamoModel
 
 CLASS_NUM = 32
 ILLEGAL_MAX_NUM = 13
@@ -355,7 +353,7 @@ class WGANGinkaLoss:
         self.lambda_gp = lambda_gp  # gradient penalty coefficient
         self.weight = weight
 
-    def compute_gradient_penalty(self, critic, stage, real_data, fake_data):
+    def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond):
         # interpolate between real and fake samples
         batch_size = real_data.size(0)
         epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device)
@@ -366,7 +364,7 @@ class WGANGinkaLoss:
         interp_data.requires_grad_()
         interp_graph.x.requires_grad_()
 
-        _, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage)
+        _, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage, tag_cond, val_cond)
 
         # compute gradients
         grad_vis = torch.autograd.grad(
@@ -392,29 +390,30 @@ class WGANGinkaLoss:
         return gp_loss
 
     def discriminator_loss(
-        self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor
+        self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor,
+        tag_cond: torch.Tensor, val_cond: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """ Discriminator (critic) loss """
         fake_data = F.softmax(fake_data, dim=1)
         real_graph = batch_convert_soft_map_to_graph(real_data)
         fake_graph = batch_convert_soft_map_to_graph(fake_data)
-        real_scores, _, _ = critic(real_data, real_graph, stage)
-        fake_scores, _, _ = critic(fake_data, fake_graph, stage)
+        real_scores, _, _ = critic(real_data, real_graph, stage, tag_cond, val_cond)
+        fake_scores, _, _ = critic(fake_data, fake_graph, stage, tag_cond, val_cond)
 
         # Wasserstein distance
         d_loss = fake_scores.mean() - real_scores.mean()
-        grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data)
+        grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data, tag_cond, val_cond)
 
         total_loss = d_loss + self.lambda_gp * grad_loss
 
         return total_loss, d_loss
 
-    def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """ Generator loss """
         probs_fake = F.softmax(fake, dim=1)
         fake_graph = batch_convert_soft_map_to_graph(probs_fake)
 
-        fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
+        fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
         minamo_loss = -torch.mean(fake_scores)
         ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio)  # the larger the mask, the smaller the cross-entropy weight
         immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
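For reference, the objective these two methods implement is the conditional WGAN-GP loss; in LaTeX (our summary of the code, with $c = (\text{tag\_cond}, \text{val\_cond})$ and $\hat{x}$ the real/fake interpolate):

    L_D = \mathbb{E}_{\tilde{x} \sim P_g}\big[D(\tilde{x}, c)\big]
        - \mathbb{E}_{x \sim P_r}\big[D(x, c)\big]
        + \lambda_{gp}\, \mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}, c) \rVert_2 - 1)^2\big]

The adversarial part of the generator loss is $-\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x}, c)]$ (the minamo_loss term), plus the cross-entropy and penalty terms above.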
@@ -439,11 +438,11 @@ class WGANGinkaLoss:
 
         return sum(losses), minamo_loss, ce_loss, immutable_loss
 
-    def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
+    def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor:
         probs_fake = F.softmax(fake, dim=1)
         fake_graph = batch_convert_soft_map_to_graph(probs_fake)
 
-        fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
+        fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
         minamo_loss = -torch.mean(fake_scores)
         constraint_loss = inner_constraint_loss(probs_fake)
 
@@ -462,11 +461,11 @@ class WGANGinkaLoss:
 
         return sum(losses)
 
-    def generator_loss_total_with_input(self, critic, stage, fake, input) -> torch.Tensor:
+    def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor:
         probs_fake = F.softmax(fake, dim=1)
         fake_graph = batch_convert_soft_map_to_graph(probs_fake)
 
-        fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
+        fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
         minamo_loss = -torch.mean(fake_scores)
         immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
         constraint_loss = inner_constraint_loss(probs_fake)
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 from .unet import GinkaUNet
 from .output import GinkaOutput
 from .input import GinkaInput, RandomInputHead
+from ..common.cond import ConditionEncoder
 
 def print_memory(tag=""):
     print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@@ -14,23 +15,27 @@ class GinkaModel(nn.Module):
         """
         super().__init__()
         self.head = RandomInputHead()
+        self.cond = ConditionEncoder(64, 16, 128, 256)
         self.input = GinkaInput(32, 32, (13, 13), (32, 32))
         self.unet = GinkaUNet(32, base_ch, base_ch)
         self.output = GinkaOutput(base_ch, out_ch, (13, 13))
 
-    def forward(self, x, stage, random=False):
+    def forward(self, x, stage, tag_cond, val_cond, random=False):
+        cond = self.cond(tag_cond, val_cond)
         if random:
-            x_in = F.softmax(self.head(x), dim=1)
+            x_in = F.softmax(self.head(x, cond), dim=1)
         else:
             x_in = x
         x = self.input(x_in)
-        x = self.unet(x)
-        x = self.output(x, stage)
+        x = self.unet(x, cond)
+        x = self.output(x, stage, cond)
         return x, x_in
 
 # check VRAM usage
 if __name__ == "__main__":
-    input = torch.randn((1, 32, 32, 32)).cuda()
+    input = torch.rand(1, 32, 32, 32).cuda()
+    tag = torch.rand(1, 64).cuda()
+    val = torch.rand(1, 16).cuda()
 
     # initialize the model
     model = GinkaModel().cuda()
@@ -38,12 +43,14 @@ if __name__ == "__main__":
     print_memory("after init")
 
     # forward pass
-    output, _ = model(input, 1, True)
+    output, _ = model(input, 1, tag, val, True)
 
     print_memory("after forward")
 
     print(f"Input shape: feat={input.shape}")
     print(f"Output shape: output={output.shape}")
     print(f"Head parameters: {sum(p.numel() for p in model.head.parameters())}")
+    print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
     print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
     print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
     print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
-from .common import GCNBlock, DoubleConvBlock
+from ..common.common import GCNBlock, DoubleConvBlock
+from ..common.cond import ConditionInjector
 
 class StageHead(nn.Module):
     def __init__(self, in_ch, out_ch, out_size=(13, 13)):
@@ -9,15 +10,21 @@ class StageHead(nn.Module):
         self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
         self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
         self.pool = nn.Sequential(
             nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
             nn.InstanceNorm2d(in_ch),
             nn.ELU(),
 
             nn.AdaptiveMaxPool2d(out_size),
             nn.Conv2d(in_ch, out_ch, 1)
         )
+        self.inject = ConditionInjector(256, in_ch)
 
-    def forward(self, x):
+    def forward(self, x, cond):
         x_cnn = self.cnn_head(x)
         x_gcn = self.gcn_head(x)
         x = torch.cat([x_cnn, x_gcn], dim=1)
         x = self.fusion(x)
+        x = self.inject(x, cond)
         x = self.pool(x)
         return x
 
@@ -28,13 +35,13 @@ class GinkaOutput(nn.Module):
         self.head2 = StageHead(in_ch, out_ch, out_size)
         self.head3 = StageHead(in_ch, out_ch, out_size)
 
-    def forward(self, x, stage):
+    def forward(self, x, stage, cond):
         if stage == 1:
-            x = self.head1(x)
+            x = self.head1(x, cond)
         elif stage == 2:
-            x = self.head2(x)
+            x = self.head2(x, cond)
         elif stage == 3:
-            x = self.head3(x)
+            x = self.head3(x, cond)
         else:
             raise RuntimeError("Unknown generate stage.")
         return x
@@ -2,7 +2,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from shared.attention import ChannelAttention
-from .common import GCNBlock, DoubleConvBlock
+from ..common.common import GCNBlock
+from ..common.cond import ConditionInjector
 
 class GinkaTransformerEncoder(nn.Module):
     def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
@@ -53,7 +54,7 @@ class ConvBlock(nn.Module):
 class FusionModule(nn.Module):
     def __init__(self, in_ch, out_ch):
         super().__init__()
-        self.conv = DoubleConvBlock([in_ch, out_ch, out_ch])
+        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate')
 
     def forward(self, x1, x2):
         x = torch.cat([x1, x2], dim=1)
@@ -66,10 +67,12 @@ class GinkaEncoder(nn.Module):
         super().__init__()
         self.conv = ConvBlock(in_ch, out_ch)
         self.pool = nn.MaxPool2d(2)
+        self.inject = ConditionInjector(256, out_ch)
 
-    def forward(self, x):
+    def forward(self, x, cond):
         x = self.conv(x)
         x = self.pool(x)
+        x = self.inject(x, cond)
         return x
 
 class GinkaGCNFusedEncoder(nn.Module):
@@ -79,12 +82,14 @@ class GinkaGCNFusedEncoder(nn.Module):
         self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
         self.pool = nn.MaxPool2d(2)
         self.fusion = FusionModule(out_ch*2, out_ch)
+        self.inject = ConditionInjector(256, out_ch)
 
-    def forward(self, x):
+    def forward(self, x, cond):
         x = self.conv(x)
         x = self.pool(x)
         x2 = self.gcn(x)
         x = self.fusion(x, x2)
+        x = self.inject(x, cond)
         return x
 
 class GinkaUpSample(nn.Module):
@@ -105,11 +110,13 @@ class GinkaDecoder(nn.Module):
         super().__init__()
         self.upsample = GinkaUpSample(in_ch, in_ch // 2)
         self.conv = ConvBlock(in_ch, out_ch)
+        self.inject = ConditionInjector(256, out_ch)
 
-    def forward(self, x, feat):
+    def forward(self, x, feat, cond):
         x = self.upsample(x)
         x = torch.cat([x, feat], dim=1)
         x = self.conv(x)
+        x = self.inject(x, cond)
         return x
 
 class GinkaGCNFusedDecoder(nn.Module):
@@ -119,13 +126,15 @@ class GinkaGCNFusedDecoder(nn.Module):
         self.conv = ConvBlock(in_ch, out_ch)
         self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
         self.fusion = FusionModule(out_ch*2, out_ch)
+        self.inject = ConditionInjector(256, out_ch)
 
-    def forward(self, x, feat):
+    def forward(self, x, feat, cond):
         x = self.upsample(x)
         x = torch.cat([x, feat], dim=1)
         x = self.conv(x)
         x2 = self.gcn(x)
         x = self.fusion(x, x2)
+        x = self.inject(x, cond)
         return x
 
 class GinkaBottleneck(nn.Module):
@@ -136,9 +145,10 @@ class GinkaBottleneck(nn.Module):
             token_size=16, ff_dim=1024, num_layers=4
         )
         self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
-        self.fusion = FusionModule(module_ch*2, module_ch)
+        self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
+        self.inject = ConditionInjector(256, module_ch)
 
-    def forward(self, x):
+    def forward(self, x, cond):
         B = x.size(0)
 
         x1 = x.view(B, 512, 16).permute(0, 2, 1)  # [B, 16, in_ch]
@@ -146,7 +156,9 @@ class GinkaBottleneck(nn.Module):
         x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4)  # [B, out_ch, 4, 4]
         x2 = self.gcn(x)
 
-        x = self.fusion(x1, x2)
+        x = torch.cat([x, x1, x2], dim=1)
+        x = self.fusion(x)
+        x = self.inject(x, cond)
 
         return x
 
@@ -162,7 +174,7 @@ class GinkaUNet(nn.Module):
         self.down1 = ConvBlock(in_ch, base_ch)
         self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
         self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
-        self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
+        self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4)
         self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
 
         self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8)
@@ -175,17 +187,17 @@ class GinkaUNet(nn.Module):
             nn.ELU(),
         )
 
-    def forward(self, x):
+    def forward(self, x, cond):
         x1 = self.down1(x)  # [B, 64, 32, 32]
-        x2 = self.down2(x1)  # [B, 128, 16, 16]
-        x3 = self.down3(x2)  # [B, 256, 8, 8]
-        x4 = self.down4(x3)  # [B, 512, 4, 4]
-        x4 = self.bottleneck(x4)  # [B, 512, 4, 4]
+        x2 = self.down2(x1, cond)  # [B, 128, 16, 16]
+        x3 = self.down3(x2, cond)  # [B, 256, 8, 8]
+        x4 = self.down4(x3, cond)  # [B, 512, 4, 4]
+        x4 = self.bottleneck(x4, cond)  # [B, 512, 4, 4]
 
         # upsampling
-        x = self.up1(x4, x3)  # [B, 256, 8, 8]
-        x = self.up2(x, x2)  # [B, 128, 16, 16]
-        x = self.up3(x, x1)  # [B, 64, 32, 32]
+        x = self.up1(x4, x3, cond)  # [B, 256, 8, 8]
+        x = self.up2(x, x2, cond)  # [B, 128, 16, 16]
+        x = self.up3(x, x1, cond)  # [B, 64, 32, 32]
         x = self.final(x)  # [B, 32, 32, 32]
 
         return x
@@ -8,13 +8,38 @@ import torch.nn.functional as F
 import cv2
 from torch_geometric.loader import DataLoader
 from tqdm import tqdm
-from .model.model import GinkaModel
+from .generator.model import GinkaModel
 from .dataset import GinkaWGANDataset
-from .model.loss import WGANGinkaLoss
-from .model.input import RandomInputHead
-from minamo.model.model import MinamoScoreModule
+from .generator.loss import WGANGinkaLoss
+from .generator.input import RandomInputHead
+from .critic.model import MinamoModel
 from shared.image import matrix_to_image_cv
 
+# Tag definitions:
+# 0. blue sea, 1. red sea, 2. indoor, 3. outdoor, 4. left-right symmetric, 5. top-bottom symmetric,
+# 6. pseudo-symmetric, 7. filler floor, 8. story floor, 9. water floor, 10. power-fantasy tower,
+# 11. boss floor, 12. pure boss floor, 13. multi-room, 14. multi-corridor, 15. item tower
+
+# Scalar value definitions:
+# 0. overall density: non-blank tiles / map area (blank also includes decorative tiles)
+# 1. monster density: monster count / map area
+# 2. resource density: resource count / map area
+# 3. door density: door count / map area
+# 4. number of entrances
+
+# Tile definitions:
+# 0. floor, 1. wall, 2. decoration (outdoor decoration, treated as floor),
+# 3. yellow door, 4. blue door, 5. red door, 6. mechanism door; other door kinds such as green doors are treated as red doors
+# 7-9. keys for the yellow/blue/red doors; mechanism doors are not opened with keys
+# 10-12. three tiers of red gems
+# 13-15. three tiers of blue gems
+# 16-18. three tiers of green gems
+# 19-21. three tiers of health potions
+# 22-24. three tiers of items
+# 25-27. three tiers of monsters
+# 28-29. reserved
+# 30. stair entrance
+# 31. arrow entrance
+
 BATCH_SIZE = 16
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
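To make the condition format concrete, here is a hypothetical example (ours, not from the dataset) of building the two vectors that ConditionEncoder(64, 16, 128, 256) consumes. Only 16 tags and 5 scalars are documented above, so the remaining slots of the 64-dim tag vector and 16-dim value vector appear to be reserved:

    import torch

    # Hypothetical condition vectors for an indoor, multi-room map.
    tag_cond = torch.zeros(64)   # multi-hot over the 16 documented tags
    tag_cond[2] = 1.0            # 2. indoor
    tag_cond[13] = 1.0           # 13. multi-room

    val_cond = torch.zeros(16)   # 5 documented scalars, rest reserved
    val_cond[0] = 0.55           # overall density
    val_cond[1] = 0.08           # monster density
    val_cond[4] = 2.0            # number of entrances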
@@ -34,27 +59,28 @@ def parse_arguments():
     parser.add_argument("--checkpoint", type=int, default=5)
     parser.add_argument("--load_optim", type=bool, default=True)
     parser.add_argument("--curr_epoch", type=int, default=20)  # minimum number of curriculum-learning epochs
+    parser.add_argument("--tuning", type=bool, default=False)
     args = parser.parse_args()
     return args
 
-def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    fake1, _ = gen(masked1, 1)
-    fake2, _ = gen(masked2, 2)
-    fake3, _ = gen(masked3, 3)
+def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    fake1, _ = gen(masked1, 1, tag, val)
+    fake2, _ = gen(masked2, 2, tag, val)
+    fake3, _ = gen(masked3, 3, tag, val)
     if detach:
         return fake1.detach(), fake2.detach(), fake3.detach()
     else:
         return fake1, fake2, fake3
 
-def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
+def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
     if progress_detach:
-        fake1, x_in = gen(input.detach(), 1, random)
-        fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2)
-        fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3)
+        fake1, x_in = gen(input.detach(), 1, tag, val, random)
+        fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val)
+        fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val)
     else:
-        fake1, x_in = gen(input, 1, random)
-        fake2, _ = gen(F.softmax(fake1, dim=1), 2)
-        fake3, _ = gen(F.softmax(fake2, dim=1), 3)
+        fake1, x_in = gen(input, 1, tag, val, random)
+        fake2, _ = gen(F.softmax(fake1, dim=1), 2, tag, val)
+        fake3, _ = gen(F.softmax(fake2, dim=1), 3, tag, val)
     if result_detach:
         return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
     else:
@@ -74,7 +100,7 @@ def train():
 
     ginka = GinkaModel().to(device)
     ginka_head = RandomInputHead().to(device)
-    minamo = MinamoScoreModule().to(device)
+    minamo = MinamoModel().to(device)
 
     dataset = GinkaWGANDataset(args.train, device)
    dataset_val = GinkaWGANDataset(args.validate, device)
@@ -133,6 +159,14 @@ def train():
 
        print("Train from loaded state.")
 
+    curr_epoch = args.curr_epoch
+
+    if args.tuning:
+        train_stage = 1
+        curr_epoch = curr_epoch // 4
     stage_epoch = 0
     mask_ratio = 0.2
 
     low_loss_epochs = 0
 
     for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
@@ -142,7 +176,14 @@ def train():
         loss_ce_total = torch.Tensor([0]).to(device)
 
         for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
-            real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
+            real1 = batch["real1"].to(device)
+            masked1 = batch["masked1"].to(device)
+            real2 = batch["real2"].to(device)
+            masked2 = batch["masked2"].to(device)
+            real3 = batch["real3"].to(device)
+            masked3 = batch["masked3"].to(device)
+            tag_cond = batch["tag_cond"].to(device)
+            val_cond = batch["val_cond"].to(device)
 
             # ---------- train the critic
             for _ in range(c_steps):
@@ -152,10 +193,10 @@ def train():
 
                 with torch.no_grad():
                     if train_stage == 1 or train_stage == 2:
-                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
+                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
                     elif train_stage == 3 or train_stage == 4:
-                        fake1, fake2, fake3, _ = gen_total(ginka, masked1, True, True, train_stage == 4)
+                        fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
 
                loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond)
                loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond)
@@ -235,7 +276,7 @@ def train():
            if train_stage == 5:
                train_stage = 2
 
-        if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= args.curr_epoch:
+        if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
            if mask_ratio >= 0.9:
                train_stage = 2
            mask_ratio += 0.2
@@ -283,13 +324,21 @@ def train():
            idx = 0
            with torch.no_grad():
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
-                    real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
+                    real1 = batch["real1"].to(device)
+                    masked1 = batch["masked1"].to(device)
+                    real2 = batch["real2"].to(device)
+                    masked2 = batch["masked2"].to(device)
+                    real3 = batch["real3"].to(device)
+                    masked3 = batch["masked3"].to(device)
+                    tag_cond = batch["tag_cond"].to(device)
+                    val_cond = batch["val_cond"].to(device)
 
                    if train_stage == 1 or train_stage == 2:
-                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
+                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
                    elif train_stage == 3 or train_stage == 4:
                        input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1, ginka.cond(tag_cond, val_cond)), dim=1)
-                        fake1, fake2, fake3, _ = gen_total(ginka, input, True, True)
+                        fake1, fake2, fake3, _ = gen_total(ginka, input, tag_cond, val_cond, True, True, train_stage == 4)
 
                    fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
                    fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
@@ -1,49 +0,0 @@
-import json
-import random
-import torch
-import torch.nn.functional as F
-from torch.utils.data import Dataset
-from shared.graph import differentiable_convert_to_data
-from shared.utils import random_smooth_onehot
-
-def load_data(path: str):
-    with open(path, 'r', encoding="utf-8") as f:
-        data = json.load(f)
-
-    data_list = []
-    for value in data["data"].values():
-        data_list.append(value)
-
-    return data_list
-
-class MinamoDataset(Dataset):
-    def __init__(self, data_path: str):
-        self.data = load_data(data_path)  # custom data-loading helper
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, idx):
-        item = self.data[idx]
-
-        map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
-        map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
-
-        min_main = random.uniform(0.6, 1)
-        max_main = random.uniform(0.8, 1)
-        epsilon = random.uniform(0, 0.4)
-
-        map1_probs = random_smooth_onehot(map1_probs, min_main, max_main, epsilon)
-        map2_probs = random_smooth_onehot(map2_probs, min_main, max_main, epsilon)
-
-        graph1 = differentiable_convert_to_data(map1_probs)
-        graph2 = differentiable_convert_to_data(map2_probs)
-
-        return (
-            map1_probs,
-            map2_probs,
-            torch.FloatTensor([item['visionSimilarity']]),
-            torch.FloatTensor([item['topoSimilarity']]),
-            graph1,
-            graph2
-        )
@@ -1,17 +0,0 @@
-import torch.nn as nn
-from tqdm import tqdm
-
-class MinamoLoss(nn.Module):
-    def __init__(self, vision_weight=0.2, topo_weight=0.8):
-        super().__init__()
-        self.vision_weight = vision_weight
-        self.topo_weight = topo_weight
-        self.loss = nn.L1Loss()
-
-    def forward(self, vis_pred, topo_pred, vis_true, topo_true):
-        # print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
-        vis_loss = self.loss(vis_pred, vis_true)
-        topo_loss = self.loss(topo_pred, topo_true)
-        # tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f} | {vis_loss.item():.12f}, {topo_loss.item():.12f}")
-        # print(vis_loss.item(), topo_loss.item())
-        return self.vision_weight * vis_loss + self.topo_weight * topo_loss
@@ -1,83 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch_geometric.nn import GCNConv, global_mean_pool
-from torch_geometric.data import Data
-
-class MinamoSimilarityVision(nn.Module):
-    def __init__(self, in_ch, out_ch):
-        super().__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_ch, in_ch * 2, 3, padding=1),
-            nn.InstanceNorm2d(in_ch * 2),
-            nn.ReLU(),
-
-            nn.Conv2d(in_ch * 2, in_ch * 4, 3, padding=1),
-            nn.InstanceNorm2d(in_ch * 4),
-            nn.ReLU(),
-
-            nn.Conv2d(in_ch * 4, in_ch * 8, 3),
-            nn.InstanceNorm2d(in_ch * 8),
-            nn.ReLU(),
-
-            nn.AdaptiveAvgPool2d(1)
-        )
-        self.fc = nn.Sequential(
-            nn.Linear(in_ch * 8, out_ch),
-        )
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
-        return x
-
-class MinamoSimilarityTopo(nn.Module):
-    def __init__(self, in_ch, hidden_dim, out_ch):
-        super().__init__()
-        self.input_fc = nn.Sequential(
-            nn.Linear(in_ch, hidden_dim),
-            nn.LayerNorm(hidden_dim),
-            nn.ReLU(),
-        )
-
-        self.conv1 = GCNConv(hidden_dim, hidden_dim*2)
-        self.conv2 = GCNConv(hidden_dim*2, hidden_dim*4)
-        self.conv3 = GCNConv(hidden_dim*4, hidden_dim*8)
-
-        self.norm1 = nn.LayerNorm(hidden_dim*2)
-        self.norm2 = nn.LayerNorm(hidden_dim*4)
-        self.norm3 = nn.LayerNorm(hidden_dim*8)
-
-        self.output_fc = nn.Sequential(
-            nn.Linear(hidden_dim*8, out_ch)
-        )
-
-    def forward(self, graph: Data):
-        x = self.input_fc(graph.x)
-
-        x = self.conv1(x, graph.edge_index)
-        x = F.relu(self.norm1(x))
-
-        x = self.conv2(x, graph.edge_index)
-        x = F.relu(self.norm2(x))
-
-        x = self.conv3(x, graph.edge_index)
-        x = F.relu(self.norm3(x))
-
-        x = global_mean_pool(x, graph.batch)
-        x = self.output_fc(x)
-
-        return x
-
-class MinamoSimilarityModel(nn.Module):
-    def __init__(self, tile_type=32):
-        super().__init__()
-        self.vision = MinamoSimilarityVision(tile_type, 512)
-        self.topo = MinamoSimilarityTopo(tile_type, 64, 512)
-
-    def forward(self, x, graph):
-        vis_feat = self.vision(x)
-        topo_feat = self.topo(graph)
-        return vis_feat, topo_feat
minamo/train.py (deleted, -153 lines)
@@ -1,153 +0,0 @@
-import os
-import sys
-from datetime import datetime
-import torch
-import torch.optim as optim
-import torch.nn.functional as F
-from torch_geometric.loader import DataLoader
-from tqdm import tqdm
-from .model.model import MinamoModel
-from .model.loss import MinamoLoss
-from .dataset import MinamoDataset
-from shared.args import parse_arguments
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-os.makedirs("result", exist_ok=True)
-os.makedirs("result/minamo_checkpoint", exist_ok=True)
-disable_tqdm = not sys.stdout.isatty()  # disable tqdm if stdout is redirected
-
-def train():
-    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
-
-    args = parse_arguments("result/minamo.pth", "minamo-dataset.json", 'minamo-eval.json')
-
-    model = MinamoModel(32)
-    model.to(device)
-
-    # prepare the datasets
-    dataset = MinamoDataset(args.train)
-    val_dataset = MinamoDataset(args.validate)
-    dataloader = DataLoader(
-        dataset,
-        batch_size=64,
-        shuffle=True
-    )
-    val_loader = DataLoader(
-        val_dataset,
-        batch_size=64,
-        shuffle=True
-    )
-
-    # set up the optimizer and scheduler
-    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
-    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
-    criterion = MinamoLoss()
-
-    if args.resume:
-        data = torch.load(args.from_state, map_location=device)
-        model.load_state_dict(data["model_state"], strict=False)
-        if args.load_optim:
-            optimizer.load_state_dict(data["optimizer_state"])
-        print("Train from loaded state.")
-
-    # for name, param in model.named_parameters():
-    #     if 'ins' not in name:  # train only the extension part
-    #         param.requires_grad = False
-
-    # start training
-    for epoch in tqdm(range(args.epochs), disable=disable_tqdm):
-        model.train()
-        total_loss = 0
-
-        # if epoch == 30:
-        #     for name, param in model.named_parameters():
-        #         param.requires_grad = True
-
-        for batch in tqdm(dataloader, leave=False, disable=disable_tqdm):
-            # move the data to the device
-            map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
-            map1 = map1.to(device)  # as [B, C, H, W]
-            map2 = map2.to(device)
-            topo_simi = topo_simi.to(device)
-            vision_simi = vision_simi.to(device)
-            graph1 = graph1.to(device)
-            graph2 = graph2.to(device)
-
-            if map1.shape[0] == 1:
-                continue
-
-            # forward pass
-            optimizer.zero_grad()
-            vision_feat1, topo_feat1 = model(map1, graph1)
-            vision_feat2, topo_feat2 = model(map2, graph2)
-
-            vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
-            topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
-
-            # compute the loss
-            loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
-
-            # backward pass
-            loss.backward()
-            optimizer.step()
-            total_loss += loss.item()
-
-        ave_loss = total_loss / len(dataloader)
-        tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
-
-        # total_norm = 0
-        # for p in model.parameters():
-        #     if p.grad is not None:
-        #         param_norm = p.grad.detach().data.norm(2)
-        #         total_norm += param_norm.item() ** 2
-        # total_norm = total_norm ** 0.5
-        # tqdm.write(f"Gradient Norm: {total_norm:.4f}")  # should normally stay between 1 and 100
-
-        # for name, param in model.named_parameters():
-        #     if param.grad is not None:
-        #         print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
-
-        # adjust the learning rate
-        scheduler.step()
-
-        # run the validation set every five epochs
-        if (epoch + 1) % 5 == 0:
-            model.eval()
-            val_loss = 0
-            with torch.no_grad():
-                for val_batch in tqdm(val_loader, leave=False, disable=disable_tqdm):
-                    map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
-                    map1_val = map1_val.to(device)
-                    map2_val = map2_val.to(device)
-                    vision_simi_val = vision_simi_val.to(device)
-                    topo_simi_val = topo_simi_val.to(device)
-                    graph1 = graph1.to(device)
-                    graph2 = graph2.to(device)
-
-                    vision_feat1, topo_feat1 = model(map1_val, graph1)
-                    vision_feat2, topo_feat2 = model(map2_val, graph2)
-
-                    vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
-                    topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
-
-                    # compute the loss
-                    loss_val = criterion(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
-                    val_loss += loss_val.item()
-
-            avg_val_loss = val_loss / len(val_loader)
-            tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
-        torch.save({
-            "model_state": model.state_dict(),
-            "optimizer_state": optimizer.state_dict(),
-        }, f"result/minamo_checkpoint/{epoch + 1}.pth")
-
-    print("Train ended.")
-
-    torch.save({
-        "model_state": model.state_dict(),
-        "optimizer_state": optimizer.state_dict(),
-    }, "result/minamo.pth")
-
-if __name__ == "__main__":
-    torch.set_num_threads(2)
-    train()
@@ -1,61 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch_geometric.loader import DataLoader
-from tqdm import tqdm
-from .model.model import MinamoModel
-from .model.loss import MinamoLoss
-from .dataset import MinamoDataset
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-def validate():
-    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
-    model = MinamoModel(32)
-    model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
-    model.to(device)
-
-    for name, param in model.named_parameters():
-        print(f"Layer: {name}, Params: {param.numel()}")
-    total_params = sum(p.numel() for p in model.parameters())
-    print(f"Total parameters: {total_params}")
-
-    # prepare the dataset
-    val_dataset = MinamoDataset("datasets/minamo-eval-1.json")
-    val_loader = DataLoader(
-        val_dataset,
-        batch_size=32,
-        shuffle=True
-    )
-
-    criterion = MinamoLoss()
-
-    model.eval()
-    val_loss = 0
-    with torch.no_grad():
-        for val_batch in tqdm(val_loader):
-            map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
-            map1_val = map1_val.to(device)
-            map2_val = map2_val.to(device)
-            vision_simi_val = vision_simi_val.to(device)
-            topo_simi_val = topo_simi_val.to(device)
-            graph1 = graph1.to(device)
-            graph2 = graph2.to(device)
-
-            vision_feat1, topo_feat1 = model(map1_val, graph1)
-            vision_feat2, topo_feat2 = model(map2_val, graph2)
-
-            vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
-            topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
-            loss_val = criterion(
-                vision_pred_val, topo_pred_val,
-                vision_simi_val, topo_simi_val
-            )
-            val_loss += loss_val.item()
-
-    avg_val_loss = val_loss / len(val_loader)
-    tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
-
-if __name__ == "__main__":
-    torch.set_num_threads(2)
-    validate()