From 44b90e7630bf2c3b4bcb49f5be0871fa5a853360 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 29 Apr 2025 18:24:01 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BF=AE=E6=94=B9=E7=9B=AE?= =?UTF-8?q?=E5=BD=95=E7=BB=93=E6=9E=84=20&=20feat:=20=E6=9D=A1=E4=BB=B6?= =?UTF-8?q?=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/{model => common}/common.py | 0 ginka/common/cond.py | 43 +++++++ {minamo/model => ginka/critic}/model.py | 78 ++++++------ {minamo/model => ginka/critic}/topo.py | 0 {minamo/model => ginka/critic}/vision.py | 0 ginka/dataset.py | 66 ++++++++-- ginka/{model => generator}/input.py | 31 +++-- ginka/{model => generator}/loss.py | 29 +++-- ginka/{model => generator}/model.py | 19 ++- ginka/{model => generator}/output.py | 19 ++- ginka/{model => generator}/unet.py | 48 ++++--- ginka/train_wgan.py | 95 ++++++++++---- minamo/dataset.py | 49 -------- minamo/model/loss.py | 17 --- minamo/model/similarity.py | 83 ------------ minamo/train.py | 153 ----------------------- minamo/validate.py | 61 --------- 17 files changed, 294 insertions(+), 497 deletions(-) rename ginka/{model => common}/common.py (100%) create mode 100644 ginka/common/cond.py rename {minamo/model => ginka/critic}/model.py (66%) rename {minamo/model => ginka/critic}/topo.py (100%) rename {minamo/model => ginka/critic}/vision.py (100%) rename ginka/{model => generator}/input.py (62%) rename ginka/{model => generator}/loss.py (96%) rename ginka/{model => generator}/model.py (69%) rename ginka/{model => generator}/output.py (66%) rename ginka/{model => generator}/unet.py (78%) delete mode 100644 minamo/dataset.py delete mode 100644 minamo/model/loss.py delete mode 100644 minamo/model/similarity.py delete mode 100644 minamo/train.py delete mode 100644 minamo/validate.py diff --git a/ginka/model/common.py b/ginka/common/common.py similarity index 100% rename from ginka/model/common.py rename to ginka/common/common.py diff --git a/ginka/common/cond.py b/ginka/common/cond.py new file mode 100644 index 0000000..c2a3a9c --- /dev/null +++ b/ginka/common/cond.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ConditionEncoder(nn.Module): + def __init__(self, tag_dim, val_dim, hidden_dim, out_dim): + super().__init__() + self.tag_embed = nn.Linear(tag_dim, hidden_dim) + self.val_embed = nn.Linear(val_dim, hidden_dim) + self.fusion = nn.Sequential( + nn.LayerNorm(hidden_dim*2), + nn.ELU(), + + nn.Linear(hidden_dim*2, hidden_dim*4), + nn.LayerNorm(hidden_dim*4), + nn.ELU(), + + nn.Linear(hidden_dim*4, out_dim) + ) + + def forward(self, tag, val): + tag = self.tag_embed(tag) + val = self.val_embed(val) + feat = torch.cat([tag, val], dim=1) + feat = self.fusion(feat) + return feat + +class ConditionInjector(nn.Module): + def __init__(self, cond_dim, out_dim): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(cond_dim, cond_dim*2), + nn.LayerNorm(cond_dim*2), + nn.ELU(), + + nn.Linear(cond_dim*2, out_dim) + ) + + def forward(self, x, cond): + cond = self.fc(cond) + B, D = cond.shape + cond = cond.view(B, D, 1, 1) + return x + cond diff --git a/minamo/model/model.py b/ginka/critic/model.py similarity index 66% rename from minamo/model/model.py rename to ginka/critic/model.py index 702cd64..dfe45a0 100644 --- a/minamo/model/model.py +++ b/ginka/critic/model.py @@ -3,30 +3,17 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn.utils import spectral_norm from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool -from .vision import MinamoVisionModel -from .topo import MinamoTopoModel from shared.constant import VISION_WEIGHT, TOPO_WEIGHT from shared.graph import batch_convert_soft_map_to_graph +from .vision import MinamoVisionModel +from .topo import MinamoTopoModel +from ..common.cond import ConditionEncoder, ConditionInjector def print_memory(tag=""): print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") -class MinamoModel(nn.Module): - def __init__(self, tile_types=32): - super().__init__() - # 视觉相似度部分 - self.vision_model = MinamoVisionModel(tile_types) - # 拓扑相似度部分 - self.topo_model = MinamoTopoModel(tile_types) - - def forward(self, map, graph): - vision_feat = self.vision_model(map) - topo_feat = self.topo_model(graph) - - return vision_feat, topo_feat - class CNNHead(nn.Module): - def __init__(self, in_ch, out_dim): + def __init__(self, in_ch): super().__init__() self.cnn = nn.Sequential( spectral_norm(nn.Conv2d(in_ch, in_ch, 3)), @@ -35,61 +22,69 @@ class CNNHead(nn.Module): nn.AdaptiveMaxPool2d((2, 2)) ) self.fc = nn.Sequential( - spectral_norm(nn.Linear(in_ch*2*2, out_dim)) + spectral_norm(nn.Linear(in_ch*2*2, 1)) ) + self.proj = nn.Linear(256, in_ch*2*2) - def forward(self, x): + def forward(self, x, cond): x = self.cnn(x) B, C, H, W = x.shape x = x.view(B, -1) - x = self.fc(x) + cond = self.proj(cond) + proj = torch.sum(x * cond, dim=1, keepdim=True) + x = self.fc(x) + proj return x class GCNHead(nn.Module): - def __init__(self, in_dim, out_dim): + def __init__(self, in_dim): super().__init__() self.gcn = GCNConv(in_dim, in_dim) + self.proj = nn.Linear(256, in_dim) self.fc = nn.Sequential( - spectral_norm(nn.Linear(in_dim, out_dim)) + spectral_norm(nn.Linear(in_dim, 1)) ) - def forward(self, x, graph): + def forward(self, x, graph, cond): x = self.gcn(x, graph.edge_index) x = F.leaky_relu(x, 0.2) x = global_max_pool(x, graph.batch) - x = self.fc(x) + cond = self.proj(cond) + proj = torch.sum(x * cond, dim=1, keepdim=True) + x = self.fc(x) + proj return x class MinamoScoreHead(nn.Module): - def __init__(self, vision_dim, topo_dim, out_dim): + def __init__(self, vision_dim, topo_dim): super().__init__() - self.vision_head = CNNHead(vision_dim, out_dim) - self.topo_head = GCNHead(topo_dim, out_dim) + self.vision_head = CNNHead(vision_dim) + self.topo_head = GCNHead(topo_dim) - def forward(self, vis, topo, graph): - vis_score = self.vision_head(vis) - topo_score = self.topo_head(topo, graph) + def forward(self, vis, topo, graph, cond): + vis_score = self.vision_head(vis, cond) + topo_score = self.topo_head(topo, graph, cond) return vis_score, topo_score -class MinamoScoreModule(nn.Module): +class MinamoModel(nn.Module): def __init__(self, tile_types=32): super().__init__() self.topo_model = MinamoTopoModel(tile_types) self.vision_model = MinamoVisionModel(tile_types) + self.cond = ConditionEncoder(64, 16, 128, 256) # 输出层 - self.head1 = MinamoScoreHead(512, 512, 1) - self.head2 = MinamoScoreHead(512, 512, 1) - self.head3 = MinamoScoreHead(512, 512, 1) + self.head1 = MinamoScoreHead(512, 512) + self.head2 = MinamoScoreHead(512, 512) + self.head3 = MinamoScoreHead(512, 512) - def forward(self, map, graph, stage): + def forward(self, map, graph, stage, tag_cond, val_cond): vision = self.vision_model(map) topo = self.topo_model(graph) + cond = self.cond(tag_cond, val_cond) if stage == 1: - vision_score, topo_score = self.head1(vision, topo, graph) + vision_score, topo_score = self.head1(vision, topo, graph, cond) elif stage == 2: - vision_score, topo_score = self.head2(vision, topo, graph) + vision_score, topo_score = self.head2(vision, topo, graph, cond) elif stage == 3: - vision_score, topo_score = self.head3(vision, topo, graph) + vision_score, topo_score = self.head3(vision, topo, graph, cond) else: raise RuntimeError("Unknown critic stage.") score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score @@ -98,19 +93,22 @@ class MinamoScoreModule(nn.Module): # 检查显存占用 if __name__ == "__main__": input = torch.randn((1, 32, 13, 13)).cuda() + tag = torch.rand(1, 64).cuda() + val = torch.rand(1, 16).cuda() # 初始化模型 - model = MinamoScoreModule().cuda() + model = MinamoModel().cuda() print_memory("初始化后") # 前向传播 - output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1) + output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1, tag, val) print_memory("前向传播后") print(f"输入形状: feat={input.shape}") print(f"输出形状: output={output.shape}") + print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}") print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}") print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}") print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}") diff --git a/minamo/model/topo.py b/ginka/critic/topo.py similarity index 100% rename from minamo/model/topo.py rename to ginka/critic/topo.py diff --git a/minamo/model/vision.py b/ginka/critic/vision.py similarity index 100% rename from minamo/model/vision.py rename to ginka/critic/vision.py diff --git a/ginka/dataset.py b/ginka/dataset.py index f87f7a2..cbbe4c2 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -87,54 +87,94 @@ class GinkaWGANDataset(Dataset): def __len__(self): return len(self.data) - def handle_stage1(self, target): + def handle_stage1(self, target, tag_cond, val_cond): # 课程学习第一阶段,蒙版填充 removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1) removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2) removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3) - return removed1, masked1, removed2, masked2, removed3, masked3 + return { + "real1": removed1, + "masked1": masked1, + "real2": removed2, + "masked2": masked2, + "real3": removed3, + "masked3": masked3, + "tag_cond": tag_cond, + "val_cond": val_cond + } - def handle_stage2(self, target): + def handle_stage2(self, target, tag_cond, val_cond): # 课程学习第二阶段,完全随机蒙版 removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) # 后面两个阶段由于会保留一些类别,所以完全随机遮挡即可 removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1)) removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1)) - return removed1, masked1, removed2, masked2, removed3, masked3 + return { + "real1": removed1, + "masked1": masked1, + "real2": removed2, + "masked2": masked2, + "real3": removed3, + "masked3": masked3, + "tag_cond": tag_cond, + "val_cond": val_cond + } - def handle_stage3(self, target): + def handle_stage3(self, target, tag_cond, val_cond): # 第三阶段,联合生成,输入随机蒙版 removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) - return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target) + + return { + "real1": removed1, + "masked1": masked1, + "real2": removed2, + "masked2": torch.zeros_like(target), + "real3": removed3, + "masked3": torch.zeros_like(target), + "tag_cond": tag_cond, + "val_cond": val_cond + } - def handle_stage4(self, target): - # 第四阶段,与第二阶段交替进行,完全随机输入 + def handle_stage4(self, target, tag_cond, val_cond): + # 第四阶段,完全随机输入 removed1 = apply_curriculum_remove(target, STAGE1_REMOVE) removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) rand = torch.rand(32, 32, 32, device=target.device) - return removed1, rand, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target) + + return { + "real1": removed1, + "masked1": rand, + "real2": removed2, + "masked2": torch.zeros_like(target), + "real3": removed3, + "masked3": torch.zeros_like(target), + "tag_cond": tag_cond, + "val_cond": val_cond + } def __getitem__(self, idx): item = self.data[idx] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] + tag_cond = torch.FloatTensor(item['tag']) + val_cond = torch.FloatTensor(item['val']) if self.train_stage == 1: - return self.handle_stage1(target) + return self.handle_stage1(target, tag_cond, val_cond) elif self.train_stage == 2: - return self.handle_stage2(target) + return self.handle_stage2(target, tag_cond, val_cond) elif self.train_stage == 3: - return self.handle_stage3(target) + return self.handle_stage3(target, tag_cond, val_cond) elif self.train_stage == 4: - return self.handle_stage4(target) + return self.handle_stage4(target, tag_cond, val_cond) raise RuntimeError(f"Invalid train stage: {self.train_stage}") \ No newline at end of file diff --git a/ginka/model/input.py b/ginka/generator/input.py similarity index 62% rename from ginka/model/input.py rename to ginka/generator/input.py index 2281cc4..da0a37b 100644 --- a/ginka/model/input.py +++ b/ginka/generator/input.py @@ -1,29 +1,34 @@ import torch import torch.nn as nn +from ..common.common import GCNBlock, DoubleConvBlock +from ..common.cond import ConditionInjector class RandomInputHead(nn.Module): def __init__(self): super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(32, 32, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(32), - nn.ELU(), - - nn.Conv2d(32, 64, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(64), - nn.ELU(), - - nn.Conv2d(64, 128, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(128), + self.conv = DoubleConvBlock([32, 64, 128]) + self.gcn = GCNBlock(32, 128, 128, 32, 32) + self.fusion = nn.Sequential( + nn.Conv2d(256, 256, 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(256), nn.ELU(), ) self.out_conv = nn.Sequential( + nn.Conv2d(256, 128, 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(128), + nn.ELU(), + nn.AdaptiveMaxPool2d((13, 13)), nn.Conv2d(128, 32, 1), ) + self.inject = ConditionInjector(256, 256) - def forward(self, x): - x = self.conv(x) + def forward(self, x, cond): + x_cnn = self.conv(x) + x_gcn = self.gcn(x) + x = torch.cat([x_cnn, x_gcn], dim=1) + x = self.fusion(x) + x = self.inject(x, cond) x = self.out_conv(x) return x diff --git a/ginka/model/loss.py b/ginka/generator/loss.py similarity index 96% rename from ginka/model/loss.py rename to ginka/generator/loss.py index 828baf2..11bd813 100644 --- a/ginka/model/loss.py +++ b/ginka/generator/loss.py @@ -4,11 +4,9 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.data import Data -from minamo.model.model import MinamoModel from shared.graph import batch_convert_soft_map_to_graph from shared.constant import VISION_WEIGHT, TOPO_WEIGHT -from shared.similarity.topo import overall_similarity, build_topological_graph -from shared.similarity.vision import calculate_visual_similarity +from ..critic.model import MinamoModel CLASS_NUM = 32 ILLEGAL_MAX_NUM = 13 @@ -355,7 +353,7 @@ class WGANGinkaLoss: self.lambda_gp = lambda_gp # 梯度惩罚系数 self.weight = weight - def compute_gradient_penalty(self, critic, stage, real_data, fake_data): + def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond): # 进行插值 batch_size = real_data.size(0) epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device) @@ -366,7 +364,7 @@ class WGANGinkaLoss: interp_data.requires_grad_() interp_graph.x.requires_grad_() - _, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage) + _, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage, tag_cond, val_cond) # 计算梯度 grad_vis = torch.autograd.grad( @@ -392,29 +390,30 @@ class WGANGinkaLoss: return gp_loss def discriminator_loss( - self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor + self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor, + tag_cond: torch.Tensor, val_cond: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """ 判别器损失函数 """ fake_data = F.softmax(fake_data, dim=1) real_graph = batch_convert_soft_map_to_graph(real_data) fake_graph = batch_convert_soft_map_to_graph(fake_data) - real_scores, _, _ = critic(real_data, real_graph, stage) - fake_scores, _, _ = critic(fake_data, fake_graph, stage) + real_scores, _, _ = critic(real_data, real_graph, stage, tag_cond, val_cond) + fake_scores, _, _ = critic(fake_data, fake_graph, stage, tag_cond, val_cond) # Wasserstein 距离 d_loss = fake_scores.mean() - real_scores.mean() - grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data) + grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data, tag_cond, val_cond) total_loss = d_loss + self.lambda_gp * grad_loss return total_loss, d_loss - def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ 生成器损失函数 """ probs_fake = F.softmax(fake, dim=1) fake_graph = batch_convert_soft_map_to_graph(probs_fake) - fake_scores, _, _ = critic(probs_fake, fake_graph, stage) + fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) minamo_loss = -torch.mean(fake_scores) ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小 immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) @@ -439,11 +438,11 @@ class WGANGinkaLoss: return sum(losses), minamo_loss, ce_loss, immutable_loss - def generator_loss_total(self, critic, stage, fake) -> torch.Tensor: + def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor: probs_fake = F.softmax(fake, dim=1) fake_graph = batch_convert_soft_map_to_graph(probs_fake) - fake_scores, _, _ = critic(probs_fake, fake_graph, stage) + fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) minamo_loss = -torch.mean(fake_scores) constraint_loss = inner_constraint_loss(probs_fake) @@ -462,11 +461,11 @@ class WGANGinkaLoss: return sum(losses) - def generator_loss_total_with_input(self, critic, stage, fake, input) -> torch.Tensor: + def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor: probs_fake = F.softmax(fake, dim=1) fake_graph = batch_convert_soft_map_to_graph(probs_fake) - fake_scores, _, _ = critic(probs_fake, fake_graph, stage) + fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) minamo_loss = -torch.mean(fake_scores) immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) constraint_loss = inner_constraint_loss(probs_fake) diff --git a/ginka/model/model.py b/ginka/generator/model.py similarity index 69% rename from ginka/model/model.py rename to ginka/generator/model.py index 7fddb70..fee210d 100644 --- a/ginka/model/model.py +++ b/ginka/generator/model.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from .unet import GinkaUNet from .output import GinkaOutput from .input import GinkaInput, RandomInputHead +from ..common.cond import ConditionEncoder def print_memory(tag=""): print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") @@ -14,23 +15,27 @@ class GinkaModel(nn.Module): """ super().__init__() self.head = RandomInputHead() + self.cond = ConditionEncoder(64, 16, 128, 256) self.input = GinkaInput(32, 32, (13, 13), (32, 32)) self.unet = GinkaUNet(32, base_ch, base_ch) self.output = GinkaOutput(base_ch, out_ch, (13, 13)) - def forward(self, x, stage, random=False): + def forward(self, x, stage, tag_cond, val_cond, random=False): + cond = self.cond(tag_cond, val_cond) if random: - x_in = F.softmax(self.head(x), dim=1) + x_in = F.softmax(self.head(x, cond), dim=1) else: x_in = x x = self.input(x_in) - x = self.unet(x) - x = self.output(x, stage) + x = self.unet(x, cond) + x = self.output(x, stage, cond) return x, x_in # 检查显存占用 if __name__ == "__main__": - input = torch.randn((1, 32, 32, 32)).cuda() + input = torch.rand(1, 32, 32, 32).cuda() + tag = torch.rand(1, 64).cuda() + val = torch.rand(1, 16).cuda() # 初始化模型 model = GinkaModel().cuda() @@ -38,12 +43,14 @@ if __name__ == "__main__": print_memory("初始化后") # 前向传播 - output, _ = model(input, 1, True) + output, _ = model(input, 1, tag, val, True) print_memory("前向传播后") print(f"输入形状: feat={input.shape}") print(f"输出形状: output={output.shape}") + print(f"Head parameters: {sum(p.numel() for p in model.head.parameters())}") + print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}") print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}") print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}") diff --git a/ginka/model/output.py b/ginka/generator/output.py similarity index 66% rename from ginka/model/output.py rename to ginka/generator/output.py index f59f8af..05802b6 100644 --- a/ginka/model/output.py +++ b/ginka/generator/output.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn -from .common import GCNBlock, DoubleConvBlock +from ..common.common import GCNBlock, DoubleConvBlock +from ..common.cond import ConditionInjector class StageHead(nn.Module): def __init__(self, in_ch, out_ch, out_size=(13, 13)): @@ -9,15 +10,21 @@ class StageHead(nn.Module): self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32) self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch]) self.pool = nn.Sequential( + nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(in_ch), + nn.ELU(), + nn.AdaptiveMaxPool2d(out_size), nn.Conv2d(in_ch, out_ch, 1) ) + self.inject = ConditionInjector(256, in_ch) - def forward(self, x): + def forward(self, x, cond): x_cnn = self.cnn_head(x) x_gcn = self.gcn_head(x) x = torch.cat([x_cnn, x_gcn], dim=1) x = self.fusion(x) + x = self.inject(x, cond) x = self.pool(x) return x @@ -28,13 +35,13 @@ class GinkaOutput(nn.Module): self.head2 = StageHead(in_ch, out_ch, out_size) self.head3 = StageHead(in_ch, out_ch, out_size) - def forward(self, x, stage): + def forward(self, x, stage, cond): if stage == 1: - x = self.head1(x) + x = self.head1(x, cond) elif stage == 2: - x = self.head2(x) + x = self.head2(x, cond) elif stage == 3: - x = self.head3(x) + x = self.head3(x, cond) else: raise RuntimeError("Unknown generate stage.") return x diff --git a/ginka/model/unet.py b/ginka/generator/unet.py similarity index 78% rename from ginka/model/unet.py rename to ginka/generator/unet.py index 7d5dfa9..057da1e 100644 --- a/ginka/model/unet.py +++ b/ginka/generator/unet.py @@ -2,7 +2,8 @@ import torch import torch.nn as nn import torch.nn.functional as F from shared.attention import ChannelAttention -from .common import GCNBlock, DoubleConvBlock +from ..common.common import GCNBlock +from ..common.cond import ConditionInjector class GinkaTransformerEncoder(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6): @@ -53,7 +54,7 @@ class ConvBlock(nn.Module): class FusionModule(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() - self.conv = DoubleConvBlock([in_ch, out_ch, out_ch]) + self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate') def forward(self, x1, x2): x = torch.cat([x1, x2], dim=1) @@ -66,10 +67,12 @@ class GinkaEncoder(nn.Module): super().__init__() self.conv = ConvBlock(in_ch, out_ch) self.pool = nn.MaxPool2d(2) + self.inject = ConditionInjector(256, out_ch) - def forward(self, x): + def forward(self, x, cond): x = self.conv(x) x = self.pool(x) + x = self.inject(x, cond) return x class GinkaGCNFusedEncoder(nn.Module): @@ -79,12 +82,14 @@ class GinkaGCNFusedEncoder(nn.Module): self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h) self.pool = nn.MaxPool2d(2) self.fusion = FusionModule(out_ch*2, out_ch) + self.inject = ConditionInjector(256, out_ch) - def forward(self, x): + def forward(self, x, cond): x = self.conv(x) x = self.pool(x) x2 = self.gcn(x) x = self.fusion(x, x2) + x = self.inject(x, cond) return x class GinkaUpSample(nn.Module): @@ -105,11 +110,13 @@ class GinkaDecoder(nn.Module): super().__init__() self.upsample = GinkaUpSample(in_ch, in_ch // 2) self.conv = ConvBlock(in_ch, out_ch) + self.inject = ConditionInjector(256, out_ch) - def forward(self, x, feat): + def forward(self, x, feat, cond): x = self.upsample(x) x = torch.cat([x, feat], dim=1) x = self.conv(x) + x = self.inject(x, cond) return x class GinkaGCNFusedDecoder(nn.Module): @@ -119,13 +126,15 @@ class GinkaGCNFusedDecoder(nn.Module): self.conv = ConvBlock(in_ch, out_ch) self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h) self.fusion = FusionModule(out_ch*2, out_ch) + self.inject = ConditionInjector(256, out_ch) - def forward(self, x, feat): + def forward(self, x, feat, cond): x = self.upsample(x) x = torch.cat([x, feat], dim=1) x = self.conv(x) x2 = self.gcn(x) x = self.fusion(x, x2) + x = self.inject(x, cond) return x class GinkaBottleneck(nn.Module): @@ -136,9 +145,10 @@ class GinkaBottleneck(nn.Module): token_size=16, ff_dim=1024, num_layers=4 ) self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4) - self.fusion = FusionModule(module_ch*2, module_ch) + self.fusion = nn.Conv2d(module_ch*3, module_ch, 1) + self.inject = ConditionInjector(256, module_ch) - def forward(self, x): + def forward(self, x, cond): B = x.size(0) x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch] @@ -146,7 +156,9 @@ class GinkaBottleneck(nn.Module): x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4] x2 = self.gcn(x) - x = self.fusion(x1, x2) + x = torch.cat([x, x1, x2], dim=1) + x = self.fusion(x) + x = self.inject(x, cond) return x @@ -162,7 +174,7 @@ class GinkaUNet(nn.Module): self.down1 = ConvBlock(in_ch, base_ch) self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16) self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8) - self.down4 = GinkaEncoder(base_ch*4, base_ch*8) + self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4) self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4) self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8) @@ -175,17 +187,17 @@ class GinkaUNet(nn.Module): nn.ELU(), ) - def forward(self, x): + def forward(self, x, cond): x1 = self.down1(x) # [B, 64, 32, 32] - x2 = self.down2(x1) # [B, 128, 16, 16] - x3 = self.down3(x2) # [B, 256, 8, 8] - x4 = self.down4(x3) # [B, 512, 4, 4] - x4 = self.bottleneck(x4) # [B, 512, 4, 4] + x2 = self.down2(x1, cond) # [B, 128, 16, 16] + x3 = self.down3(x2, cond) # [B, 256, 8, 8] + x4 = self.down4(x3, cond) # [B, 512, 4, 4] + x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4] # 上采样 - x = self.up1(x4, x3) # [B, 256, 8, 8] - x = self.up2(x, x2) # [B, 128, 16, 16] - x = self.up3(x, x1) # [B, 64, 32, 32] + x = self.up1(x4, x3, cond) # [B, 256, 8, 8] + x = self.up2(x, x2, cond) # [B, 128, 16, 16] + x = self.up3(x, x1, cond) # [B, 64, 32, 32] x = self.final(x) # [B, 32, 32, 32] return x diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 06633ce..3c2e0c3 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -8,13 +8,38 @@ import torch.nn.functional as F import cv2 from torch_geometric.loader import DataLoader from tqdm import tqdm -from .model.model import GinkaModel +from .generator.model import GinkaModel from .dataset import GinkaWGANDataset -from .model.loss import WGANGinkaLoss -from .model.input import RandomInputHead -from minamo.model.model import MinamoScoreModule +from .generator.loss import WGANGinkaLoss +from .generator.input import RandomInputHead +from .critic.model import MinamoModel from shared.image import matrix_to_image_cv +# 标签定义: +# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, +# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具塔 + +# 标量值定义: +# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块 +# 1. 怪物密度,怪物数量/地图面积 +# 2. 资源密度,资源数量/地图面积 +# 3. 门密度,门数量/地图面积 +# 4. 入口数量 + +# 图块定义: +# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地), +# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门 +# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启 +# 10-12. 三种等级的红宝石 +# 13-15. 三种等级的蓝宝石 +# 16-18. 三种等级的绿宝石 +# 19-21. 三种等级的血瓶 +# 22-24. 三种等级的道具 +# 25-27. 三种等级的怪物 +# 28-29. 留空 +# 30. 楼梯入口 +# 31. 箭头入口 + BATCH_SIZE = 16 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -34,27 +59,28 @@ def parse_arguments(): parser.add_argument("--checkpoint", type=int, default=5) parser.add_argument("--load_optim", type=bool, default=True) parser.add_argument("--curr_epoch", type=int, default=20) # 课程学习至少多少 epoch + parser.add_argument("--tuning", type=bool, default=False) args = parser.parse_args() return args -def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - fake1, _ = gen(masked1, 1) - fake2, _ = gen(masked2, 2) - fake3, _ = gen(masked3, 3) +def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + fake1, _ = gen(masked1, 1, False, tag, val) + fake2, _ = gen(masked2, 2, False, tag, val) + fake3, _ = gen(masked3, 3, False, tag, val) if detach: return fake1.detach(), fake2.detach(), fake3.detach() else: return fake1, fake2, fake3 -def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> torch.Tensor: +def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> torch.Tensor: if progress_detach: - fake1, x_in = gen(input.detach(), 1, random) - fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2) - fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3) + fake1, x_in = gen(input.detach(), 1, random, tag, val) + fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, False, tag, val) + fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, False, tag, val) else: - fake1, x_in = gen(input, 1, random) - fake2, _ = gen(F.softmax(fake1, dim=1), 2) - fake3, _ = gen(F.softmax(fake2, dim=1), 3) + fake1, x_in = gen(input, 1, random, tag, val) + fake2, _ = gen(F.softmax(fake1, dim=1), 2, False, tag, val) + fake3, _ = gen(F.softmax(fake2, dim=1), 3, False, tag, val) if result_detach: return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach() else: @@ -74,7 +100,7 @@ def train(): ginka = GinkaModel().to(device) ginka_head = RandomInputHead().to(device) - minamo = MinamoScoreModule().to(device) + minamo = MinamoModel().to(device) dataset = GinkaWGANDataset(args.train, device) dataset_val = GinkaWGANDataset(args.validate, device) @@ -133,6 +159,14 @@ def train(): print("Train from loaded state.") + curr_epoch = args.curr_epoch + + if args.tuning: + train_stage = 1 + curr_epoch = curr_epoch // 4 + stage_epoch = 0 + mask_ratio = 0.2 + low_loss_epochs = 0 for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm): @@ -142,7 +176,14 @@ def train(): loss_ce_total = torch.Tensor([0]).to(device) for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch] + real1 = batch["real1"].to(device) + masked1 = batch["masked1"].to(device) + real2 = batch["real2"].to(device) + masked2 = batch["masked2"].to(device) + real3 = batch["real3"].to(device) + masked3 = batch["masked3"].to(device) + tag_cond = batch["tag_cond"].to(device) + val_cond = batch["val_cond"].to(device) # ---------- 训练判别器 for _ in range(c_steps): @@ -152,10 +193,10 @@ def train(): with torch.no_grad(): if train_stage == 1 or train_stage == 2: - fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True) + fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, _ = gen_total(ginka, masked1, True, True, train_stage == 4) + fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1) loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2) @@ -235,7 +276,7 @@ def train(): if train_stage == 5: train_stage = 2 - if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= args.curr_epoch: + if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch: if mask_ratio >= 0.9: train_stage = 2 mask_ratio += 0.2 @@ -283,13 +324,21 @@ def train(): idx = 0 with torch.no_grad(): for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): - real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch] + real1 = batch["real1"].to(device) + masked1 = batch["masked1"].to(device) + real2 = batch["real2"].to(device) + masked2 = batch["masked2"].to(device) + real3 = batch["real3"].to(device) + masked3 = batch["masked3"].to(device) + tag_cond = batch["tag_cond"].to(device) + val_cond = batch["val_cond"].to(device) + if train_stage == 1 or train_stage == 2: - fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True) + fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) elif train_stage == 3 or train_stage == 4: input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1) - fake1, fake2, fake3, _ = gen_total(ginka, input, True, True) + fake1, fake2, fake3, _ = gen_total(ginka, input, tag_cond, val_cond, True, True, train_stage == 4) fake1 = torch.argmax(fake1, dim=1).cpu().numpy() fake2 = torch.argmax(fake2, dim=1).cpu().numpy() diff --git a/minamo/dataset.py b/minamo/dataset.py deleted file mode 100644 index fb99639..0000000 --- a/minamo/dataset.py +++ /dev/null @@ -1,49 +0,0 @@ -import json -import random -import torch -import torch.nn.functional as F -from torch.utils.data import Dataset -from shared.graph import differentiable_convert_to_data -from shared.utils import random_smooth_onehot - -def load_data(path: str): - with open(path, 'r', encoding="utf-8") as f: - data = json.load(f) - - data_list = [] - for value in data["data"].values(): - data_list.append(value) - - return data_list - -class MinamoDataset(Dataset): - def __init__(self, data_path: str): - self.data = load_data(data_path) # 自定义数据加载函数 - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - - map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - - min_main = random.uniform(0.6, 1) - max_main = random.uniform(0.8, 1) - epsilon = random.uniform(0, 0.4) - - map1_probs = random_smooth_onehot(map1_probs, min_main, max_main, epsilon) - map2_probs = random_smooth_onehot(map2_probs, min_main, max_main, epsilon) - - graph1 = differentiable_convert_to_data(map1_probs) - graph2 = differentiable_convert_to_data(map2_probs) - - return ( - map1_probs, - map2_probs, - torch.FloatTensor([item['visionSimilarity']]), - torch.FloatTensor([item['topoSimilarity']]), - graph1, - graph2 - ) diff --git a/minamo/model/loss.py b/minamo/model/loss.py deleted file mode 100644 index 5fe818f..0000000 --- a/minamo/model/loss.py +++ /dev/null @@ -1,17 +0,0 @@ -import torch.nn as nn -from tqdm import tqdm - -class MinamoLoss(nn.Module): - def __init__(self, vision_weight=0.2, topo_weight=0.8): - super().__init__() - self.vision_weight = vision_weight - self.topo_weight = topo_weight - self.loss = nn.L1Loss() - - def forward(self, vis_pred, topo_pred, vis_true, topo_true): - # print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape) - vis_loss = self.loss(vis_pred, vis_true) - topo_loss = self.loss(topo_pred, topo_true) - # tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f} | {vis_loss.item():.12f}, {topo_loss.item():.12f}") - # print(vis_loss.item(), topo_loss.item()) - return self.vision_weight * vis_loss + self.topo_weight * topo_loss diff --git a/minamo/model/similarity.py b/minamo/model/similarity.py deleted file mode 100644 index 41d3ba3..0000000 --- a/minamo/model/similarity.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch_geometric.nn import GCNConv, global_mean_pool -from torch_geometric.data import Data - -class MinamoSimilarityVision(nn.Module): - def __init__(self, in_ch, out_ch): - super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(in_ch, in_ch * 2, 3, padding=1), - nn.InstanceNorm2d(in_ch * 2), - nn.ReLU(), - - nn.Conv2d(in_ch * 2, in_ch * 4, 3, padding=1), - nn.InstanceNorm2d(in_ch * 4), - nn.ReLU(), - - nn.Conv2d(in_ch * 4, in_ch * 8, 3), - nn.InstanceNorm2d(in_ch * 8), - nn.ReLU(), - - nn.AdaptiveAvgPool2d(1) - ) - self.fc = nn.Sequential( - nn.Linear(in_ch * 8, out_ch), - ) - - def forward(self, x): - x = self.conv(x) - x = x.view(x.size(0), -1) - x = self.fc(x) - return x - -class MinamoSimilarityTopo(nn.Module): - def __init__(self, in_ch, hidden_dim, out_ch): - super().__init__() - self.input_fc = nn.Sequential( - nn.Linear(in_ch, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.ReLU(), - ) - - self.conv1 = GCNConv(hidden_dim, hidden_dim*2) - self.conv2 = GCNConv(hidden_dim*2, hidden_dim*4) - self.conv3 = GCNConv(hidden_dim*4, hidden_dim*8) - - self.norm1 = nn.LayerNorm(hidden_dim*2) - self.norm2 = nn.LayerNorm(hidden_dim*4) - self.norm3 = nn.LayerNorm(hidden_dim*8) - - self.output_fc = nn.Sequential( - nn.Linear(hidden_dim*8, out_ch) - ) - - def forward(self, graph: Data): - x = self.input_fc(graph.x) - - x = self.conv1(x, graph.edge_index) - x = F.relu(self.norm1(x)) - - x = self.conv2(x, graph.edge_index) - x = F.relu(self.norm2(x)) - - x = self.conv3(x, graph.edge_index) - x = F.relu(self.norm3(x)) - - x = global_mean_pool(x, graph.batch) - x = self.output_fc(x) - - return x - -class MinamoSimilarityModel(nn.Module): - def __init__(self, tile_type=32): - super().__init__() - self.vision = MinamoSimilarityVision(tile_type, 512) - self.topo = MinamoSimilarityTopo(tile_type, 64, 512) - - def forward(self, x, graph): - vis_feat = self.vision(x) - topo_feat = self.topo(graph) - return vis_feat, topo_feat - \ No newline at end of file diff --git a/minamo/train.py b/minamo/train.py deleted file mode 100644 index 91991eb..0000000 --- a/minamo/train.py +++ /dev/null @@ -1,153 +0,0 @@ -import os -import sys -from datetime import datetime -import torch -import torch.optim as optim -import torch.nn.functional as F -from torch_geometric.loader import DataLoader -from tqdm import tqdm -from .model.model import MinamoModel -from .model.loss import MinamoLoss -from .dataset import MinamoDataset -from shared.args import parse_arguments - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -os.makedirs("result", exist_ok=True) -os.makedirs("result/minamo_checkpoint", exist_ok=True) -disable_tqdm = not sys.stdout.isatty() # 如果 stdout 被重定向,则禁用 tqdm - -def train(): - print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") - - args = parse_arguments("result/minamo.pth", "minamo-dataset.json", 'minamo-eval.json') - - model = MinamoModel(32) - model.to(device) - - # 准备数据集 - dataset = MinamoDataset(args.train) - val_dataset = MinamoDataset(args.validate) - dataloader = DataLoader( - dataset, - batch_size=64, - shuffle=True - ) - val_loader = DataLoader( - val_dataset, - batch_size=64, - shuffle=True - ) - - # 设定优化器与调度器 - optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2) - scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) - criterion = MinamoLoss() - - if args.resume: - data = torch.load(args.from_state, map_location=device) - model.load_state_dict(data["model_state"], strict=False) - if args.load_optim: - optimizer.load_state_dict(data["optimizer_state"]) - print("Train from loaded state.") - - # for name, param in model.named_parameters(): - # if 'ins' not in name: # 仅训练扩展部分 - # param.requires_grad = False - - # 开始训练 - for epoch in tqdm(range(args.epochs), disable=disable_tqdm): - model.train() - total_loss = 0 - - # if epoch == 30: - # for name, param in model.named_parameters(): - # param.requires_grad = True - - for batch in tqdm(dataloader, leave=False, disable=disable_tqdm): - # 数据迁移到设备 - map1, map2, vision_simi, topo_simi, graph1, graph2 = batch - map1 = map1.to(device) # 转为 [B, C, H, W] - map2 = map2.to(device) - topo_simi = topo_simi.to(device) - vision_simi = vision_simi.to(device) - graph1 = graph1.to(device) - graph2 = graph2.to(device) - - if map1.shape[0] == 1: - continue - - # 前向传播 - optimizer.zero_grad() - vision_feat1, topo_feat1 = model(map1, graph1) - vision_feat2, topo_feat2 = model(map2, graph2) - - vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) - topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) - - # 计算损失 - loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi) - - # 反向传播 - loss.backward() - optimizer.step() - total_loss += loss.item() - - ave_loss = total_loss / len(dataloader) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") - - # total_norm = 0 - # for p in model.parameters(): - # if p.grad is not None: - # param_norm = p.grad.detach().data.norm(2) - # total_norm += param_norm.item() ** 2 - # total_norm = total_norm ** 0.5 - # tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间 - - # for name, param in model.named_parameters(): - # if param.grad is not None: - # print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}") - - # 学习率调整 - scheduler.step() - - # 每十轮推理一次验证集 - if (epoch + 1) % 5 == 0: - model.eval() - val_loss = 0 - with torch.no_grad(): - for val_batch in tqdm(val_loader, leave=False, disable=disable_tqdm): - map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch - map1_val = map1_val.to(device) - map2_val = map2_val.to(device) - vision_simi_val = vision_simi_val.to(device) - topo_simi_val = topo_simi_val.to(device) - graph1 = graph1.to(device) - graph2 = graph2.to(device) - - vision_feat1, topo_feat1 = model(map1_val, graph1) - vision_feat2, topo_feat2 = model(map2_val, graph2) - - vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) - topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) - - # 计算损失 - loss_val = criterion(vision_pred, topo_pred, vision_simi_val, topo_simi_val) - val_loss += loss_val.item() - - avg_val_loss = val_loss / len(val_loader) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") - torch.save({ - "model_state": model.state_dict(), - "optimizer_state": optimizer.state_dict(), - }, f"result/minamo_checkpoint/{epoch + 1}.pth") - - print("Train ended.") - - torch.save({ - "model_state": model.state_dict(), - "optimizer_state": optimizer.state_dict(), - }, "result/minamo.pth") - -if __name__ == "__main__": - torch.set_num_threads(2) - train() diff --git a/minamo/validate.py b/minamo/validate.py deleted file mode 100644 index d9635c8..0000000 --- a/minamo/validate.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch -import torch.nn.functional as F -from torch_geometric.loader import DataLoader -from tqdm import tqdm -from .model.model import MinamoModel -from .model.loss import MinamoLoss -from .dataset import MinamoDataset - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -def validate(): - print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.") - model = MinamoModel(32) - model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) - model.to(device) - - for name, param in model.named_parameters(): - print(f"Layer: {name}, Params: {param.numel()}") - total_params = sum(p.numel() for p in model.parameters()) - print(f"Total parameters: {total_params}") - - # 准备数据集 - val_dataset = MinamoDataset("datasets/minamo-eval-1.json") - val_loader = DataLoader( - val_dataset, - batch_size=32, - shuffle=True - ) - - criterion = MinamoLoss() - - model.eval() - val_loss = 0 - with torch.no_grad(): - for val_batch in tqdm(val_loader): - map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch - map1_val = map1_val.to(device) - map2_val = map2_val.to(device) - vision_simi_val = vision_simi_val.to(device) - topo_simi_val = topo_simi_val.to(device) - graph1 = graph1.to(device) - graph2 = graph2.to(device) - - vision_feat1, topo_feat1 = model(map1_val, graph1) - vision_feat2, topo_feat2 = model(map2_val, graph2) - - vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) - topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) - loss_val = criterion( - vision_pred_val, topo_pred_val, - vision_simi_val, topo_simi_val - ) - val_loss += loss_val.item() - - avg_val_loss = val_loss / len(val_loader) - tqdm.write(f"Validation::loss: {avg_val_loss:.6f}") - -if __name__ == "__main__": - torch.set_num_threads(2) - validate() - \ No newline at end of file