diff --git a/ginka/dataset.py b/ginka/dataset.py index c1abe53..0142f0b 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -3,10 +3,18 @@ import random import torch import torch.nn.functional as F from torch.utils.data import Dataset -from minamo.model.model import MinamoModel -from shared.graph import differentiable_convert_to_data +import torch +import torch.nn.functional as F +from typing import List from shared.utils import random_smooth_onehot +STAGE1_MASK = [0, 1, 10, 11] +STAGE1_REMOVE = [2, 3, 4, 5, 6, 7, 8, 9, 12] +STAGE2_MASK = [6, 7, 8, 9] +STAGE2_REMOVE = [2, 3, 4, 5, 12] +STAGE3_MASK = [2, 3, 4, 5, 12] +STAGE3_REMOVE = [] + def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: data = json.load(f) @@ -23,38 +31,45 @@ def load_minamo_gan_data(data: list): res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True)) return res -class GinkaDataset(Dataset): - def __init__(self, data_path: str, device, minamo: MinamoModel): - self.data = load_data(data_path) # 自定义数据加载函数 - self.max_size = 32 - self.minamo = minamo - self.device = device +def apply_curriculum_mask( + maps: torch.Tensor, # [B, C, H, W] + mask_classes: List[int], # 要遮挡的类别索引 + remove_classes: List[int], # 要移除的类别索引 + mask_ratio: float # 遮挡比例 0~1 +) -> torch.Tensor: + C, H, W = maps.shape + device = maps.device + masked_maps = maps.clone() - def __len__(self): - return len(self.data) + # Step 1: 移除不需要的类别(全设为 0 类) + if remove_classes: + remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0 + masked_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0 + masked_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地” + + removed_maps = masked_maps.clone() - def __getitem__(self, idx): - item = self.data[idx] - - target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - min_main = random.uniform(0.75, 0.9) - max_main = random.uniform(0.9, 1) - epsilon = random.uniform(0, 0.25) - target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon) - graph = differentiable_convert_to_data(target_smooth).to(self.device) - target = target.to(self.device) - vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) - - return { - "target_vision_feat": vision_feat, - "target_topo_feat": topo_feat, - "target": target, - } + # Step 2: 对指定类别随机遮挡 + for cls in mask_classes: + cls_mask = masked_maps[:, cls] > 0 # 目标类别的像素布尔掩码 [H, W] + indices = cls_mask.nonzero(as_tuple=False) # 所有该类像素坐标 + num_mask = int(len(indices) * mask_ratio) + if num_mask > 0: + selected = indices[torch.randperm(len(indices))[:num_mask]] + masked_maps[cls, selected[:, 0], selected[:, 1]] = 0 + masked_maps[0, selected[:, 0], selected[:, 1]] = 1 # 置为“空地” + + return removed_maps, masked_maps class GinkaWGANDataset(Dataset): def __init__(self, data_path: str, device): self.data = load_data(data_path) # 自定义数据加载函数 self.device = device + self.train_stage = 1 + self.mask_ratio1 = 0.1 + self.mask_ratio2 = 0.1 + self.mask_ratio3 = 0.1 + self.random_ratio = 0.0 def __len__(self): return len(self.data) @@ -63,56 +78,20 @@ class GinkaWGANDataset(Dataset): item = self.data[idx] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - # min_main = random.uniform(0.8, 0.9) - # max_main = random.uniform(0.9, 1) - # epsilon = random.uniform(0, 0.2) - # target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device) + + if self.train_stage == 1: + removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1) + removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2) + removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3) + elif self.train_stage == 2: + removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) + removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 0.9)) + removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 0.9)) - return target - -class MinamoGANDataset(Dataset): - def __init__(self, refer_data_path): - self.refer = load_minamo_gan_data(load_data(refer_data_path)) - self.data = list() - self.data.extend(random.sample(self.refer, 1000)) - - def set_data(self, data: list): - self.data.clear() - self.data.extend(data) - k = min(len(data) / 4, len(self.refer)) - self.data.extend(random.sample(self.refer, int(k))) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - # 假定 map2 是参考地图 - item = self.data[idx] - - map1, map2, vis_sim, topo_sim, review = item - # 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换 - if review: - map1 = F.one_hot(torch.LongTensor(map1), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - else: - map1 = torch.FloatTensor(map1) - map2 = F.one_hot(torch.LongTensor(map2), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - - min_main = random.uniform(0.75, 0.9) - max_main = random.uniform(0.9, 1) - epsilon = random.uniform(0, 0.25) - - if review: - map1 = random_smooth_onehot(map1, min_main, max_main, epsilon) - map2 = random_smooth_onehot(map2, min_main, max_main, epsilon) - - graph1 = differentiable_convert_to_data(map1) - graph2 = differentiable_convert_to_data(map2) - - return ( - map1, - map2, - torch.FloatTensor([vis_sim]), - torch.FloatTensor([topo_sim]), - graph1, - graph2 - ) \ No newline at end of file + if self.random_ratio > 0: + removed1 = random_smooth_onehot(removed1, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio) + removed2 = random_smooth_onehot(removed2, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio) + removed3 = random_smooth_onehot(removed3, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio) + + return removed1, masked1, removed2, masked2, removed3, masked3 + \ No newline at end of file diff --git a/ginka/model/input.py b/ginka/model/input.py index 7796367..1d13496 100644 --- a/ginka/model/input.py +++ b/ginka/model/input.py @@ -2,43 +2,24 @@ import torch import torch.nn as nn class GinkaInput(nn.Module): - def __init__(self, feat_dim=1024, out_ch=1, size=(32, 32)): + def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)): super().__init__() + self.out_size = out_size self.fc = nn.Sequential( - nn.Linear(feat_dim, size[0] * size[1] * out_ch), - nn.Unflatten(1, (out_ch, *size)) + nn.Linear(in_size[0] * in_size[1], out_size[0] * out_size[1]), + nn.LayerNorm(out_size[0] * out_size[1]), + nn.ELU() + ) + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(out_ch), + nn.ELU() ) def forward(self, x): + B, C, H, W = x.shape + x = x.view(B, C, H * W) x = self.fc(x) + x = x.view(B, C, self.out_size[0], self.out_size[1]) + x = self.conv(x) return x - -class FeatureEncoder(nn.Module): - def __init__(self, feat_dim, size, mid_ch, out_ch): - super().__init__() - self.encode = nn.Sequential( - nn.Linear(feat_dim, mid_ch * size * size), - nn.Unflatten(1, (mid_ch, size, size)), - nn.Conv2d(mid_ch, out_ch, 1) - ) - - def forward(self, x): - x = self.encode(x) - return x - -class GinkaFeatureInput(nn.Module): - def __init__(self, feat_dim=1024, mid_ch=1, out_ch=64): - super().__init__() - self.encode1 = FeatureEncoder(feat_dim, 32, mid_ch, out_ch) - self.encode2 = FeatureEncoder(feat_dim, 16, mid_ch * 2, out_ch * 2) - self.encode3 = FeatureEncoder(feat_dim, 8, mid_ch * 4, out_ch * 4) - self.encode4 = FeatureEncoder(feat_dim, 4, mid_ch * 8, out_ch * 8) - self.encode5 = FeatureEncoder(feat_dim, 2, mid_ch * 16, out_ch * 16) - - def forward(self, x): - x1 = self.encode1(x) - x2 = self.encode2(x) - x3 = self.encode3(x) - x4 = self.encode4(x) - x5 = self.encode5(x) - return x1, x2, x3, x4, x5 diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 0c5e894..ffce4cb 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -13,6 +13,13 @@ from shared.similarity.vision import calculate_visual_similarity CLASS_NUM = 32 ILLEGAL_MAX_NUM = 12 +STAGE_ALLOWED = [ + [], + [0, 1, 10, 11], + [6, 7, 8, 9,], + [2, 3, 4, 5, 12] +] + def get_not_allowed(classes: list[int], include_illegal=False): res = list() for num in range(0, CLASS_NUM): @@ -301,24 +308,47 @@ def js_divergence(p, q, eps=1e-8): return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0) +def immutable_penalty_loss( + pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int] +) -> torch.Tensor: + """ + 惩罚模型修改不可更改区域的损失。 + + Args: + input: 模型输出 [B, C, H, W],概率分布 (softmax 后) + target: 原始输入图 [B, C, H, W],概率分布 (softmax 后) + modifiable_classes: 允许被修改的类别列表 + penalty_weight: 对非允许修改区域的惩罚系数 + """ + not_allowed = get_not_allowed(modifiable_classes, include_illegal=True) + input_mask = pred[:, not_allowed, :, :] + with torch.no_grad(): + target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1) + target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float() + + # 差异区域(模型试图改变的地方) + penalty = F.cross_entropy(input_mask, target_mask) + + return penalty + class WGANGinkaLoss: - def __init__(self, lambda_gp=100, weight=[0.8, 0.1, 0.1], diversity_lamda=0.4): + def __init__(self, lambda_gp=100, weight=[1, 0.4, 10, 0.2, 0.2]): + # weight: 判别器损失,L1 损失,不可修改类型损失 self.lambda_gp = lambda_gp # 梯度惩罚系数 self.weight = weight - self.diversity_lamda = diversity_lamda - def compute_gradient_penalty(self, critic, real_data, fake_data): + def compute_gradient_penalty(self, critic, stage, real_data, fake_data): # 进行插值 batch_size = real_data.size(0) epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device) - interp_data = interpolate_data(real_data, fake_data, epsilon_data) - interp_graph = batch_convert_soft_map_to_graph(interp_data) + interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device) + interp_graph = batch_convert_soft_map_to_graph(interp_data).to(real_data.device) # 对图像进行反向传播并计算梯度 interp_data.requires_grad_() interp_graph.x.requires_grad_() - _, d_vis_score, d_topo_score = critic(interp_data, interp_graph) + _, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage) # 计算梯度 grad_vis = torch.autograd.grad( @@ -344,21 +374,21 @@ class WGANGinkaLoss: return gp_loss def discriminator_loss( - self, critic, real_data: torch.Tensor, - real_graph: torch.Tensor, fake_data: torch.Tensor - ): + self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: """ 判别器损失函数 """ + real_graph = batch_convert_soft_map_to_graph(real_data) fake_graph = batch_convert_soft_map_to_graph(fake_data) - real_scores, _, _ = critic(real_data, real_graph) - fake_scores, _, _ = critic(fake_data, fake_graph) - - # print("Critic 输出范围", fake_scores.min().item(), fake_scores.max().item(), real_scores.min().item(), real_scores.max().item()) + real_scores, _, _ = critic(real_data, real_graph, stage) + fake_scores, _, _ = critic(fake_data, fake_graph, stage) # Wasserstein 距离 d_loss = fake_scores.mean() - real_scores.mean() - grad_loss = self.compute_gradient_penalty(critic, real_data, fake_data) + grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data) - return d_loss, d_loss + self.lambda_gp * grad_loss + total_loss = d_loss + self.lambda_gp * grad_loss + + return total_loss, d_loss def calculate_similarity_one(self, map1, map2): topo1 = build_topological_graph(map1) @@ -368,73 +398,29 @@ class WGANGinkaLoss: topo_sim = overall_similarity(topo1, topo2) return vis_sim, topo_sim - - def discriminator_loss_assist(self, critic, fake_data1, fake_data2): - graph1 = batch_convert_soft_map_to_graph(fake_data1) - graph2 = batch_convert_soft_map_to_graph(fake_data2) - vis_feat_1, topo_feat_1 = critic(fake_data1, graph1) - vis_feat_2, topo_feat_2 = critic(fake_data2, graph2) + + def generator_loss(self, critic, stage, mask_ratio, real, fake, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ 生成器损失函数 """ + fake_graph = batch_convert_soft_map_to_graph(fake) - batch1 = torch.argmax(fake_data1, dim=1).cpu().tolist() - batch2 = torch.argmax(fake_data2, dim=1).cpu().tolist() - - vis_sim_real = [] - topo_sim_real = [] - - for i in range(len(batch1)): - vis_sim, topo_sim = self.calculate_similarity_one(batch1[i], batch2[i]) - vis_sim_real.append(vis_sim) - topo_sim_real.append(topo_sim) - - vis_sim_real = torch.Tensor(vis_sim_real) - topo_sim_real = torch.Tensor(topo_sim_real) - - pred_vis_sim = F.cosine_similarity(vis_feat_1, vis_feat_2).cpu() - pred_topo_sim = F.cosine_similarity(topo_feat_1, topo_feat_2).cpu() - - loss1 = F.l1_loss(pred_vis_sim, vis_sim_real) * VISION_WEIGHT + F.l1_loss(pred_topo_sim, topo_sim_real) * TOPO_WEIGHT - - return loss1 - - def discriminator_loss_assist2(self, critic, real_data, fake_data1, fake_data2): - loss1 = self.discriminator_loss_assist(critic, real_data, fake_data1) - loss2 = self.discriminator_loss_assist(critic, real_data, fake_data2) - loss3 = self.discriminator_loss_assist(critic, fake_data1, fake_data2) - - return loss1 / 3.0 + loss2 / 3.0 + loss3 / 3.0 - - def generator_loss_one(self, critic, fake, fake_graph): - fake_scores, _, _ = critic(fake, fake_graph) + fake_scores, _, _ = critic(fake, fake_graph, stage) minamo_loss = -torch.mean(fake_scores) - class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake) - entrance_loss = entrance_constraint_loss(fake) + ce_loss = F.cross_entropy(fake, real) + immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) + constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake) losses = [ minamo_loss * self.weight[0], - class_loss * self.weight[1], - entrance_loss * self.weight[2] + ce_loss * self.weight[1] / mask_ratio * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小 + immutable_loss * self.weight[2], + constraint_loss * self.weight[3] ] - return sum(losses) - - def generator_loss(self, critic, critic_assist, fake1, fake2): - """ 生成器损失函数 """ - fake_graph1 = batch_convert_soft_map_to_graph(fake1) - fake_graph2 = batch_convert_soft_map_to_graph(fake2) + if stage == 1: + # 第一个阶段检查入口存在性 + entrance_loss = entrance_constraint_loss(fake) + losses.append(entrance_loss * self.weight[4]) - loss1 = self.generator_loss_one(critic, fake1, fake_graph1) - loss2 = self.generator_loss_one(critic, fake2, fake_graph2) + # print(losses[2].item()) - # vis_feat1, topo_feat1 = critic_assist(fake1, fake_graph1) - # vis_feat2, topo_feat2 = critic_assist(fake2, fake_graph2) - - # vis_sim = F.cosine_similarity(vis_feat1, vis_feat2) - # topo_sim = F.cosine_similarity(topo_feat1, topo_feat2) - # similarity = vis_sim * VISION_WEIGHT + topo_sim * TOPO_WEIGHT - - # print(similarity.mean().item()) - # div_loss = F.l1_loss(fake1[:, :, 1:-1, 1:-1], fake2[:, :, 1:-1, 1:-1]) - - return loss1 * 0.5 + loss2 * 0.5\ - # + self.diversity_lamda * F.relu(0.7 - div_loss).mean() - # + self.diversity_lamda * F.relu(similarity - 0.4).mean() + return sum(losses), minamo_loss, ce_loss / mask_ratio, immutable_loss diff --git a/ginka/model/model.py b/ginka/model/model.py index 6ef6cbf..9079cb5 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -9,27 +9,29 @@ def print_memory(tag=""): print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") class GinkaModel(nn.Module): - def __init__(self, feat_dim=1024, base_ch=64, out_ch=32): + def __init__(self, base_ch=64, out_ch=32): """Ginka Model 模型定义部分 """ super().__init__() - self.unet = GinkaUNet(base_ch, base_ch, feat_dim) + self.input = GinkaInput(32, 32, (13, 13), (32, 32)) + self.unet = GinkaUNet(32, base_ch, base_ch) self.output = GinkaOutput(base_ch, out_ch, (13, 13)) - def forward(self, x): + def forward(self, x, stage): """ Args: x: 参考地图的特征向量 Returns: logits: 输出logits [BS, num_classes, H, W] """ + x = self.input(x) x = self.unet(x) - x = self.output(x) + x = self.output(x, stage) return F.softmax(x, dim=1) # 检查显存占用 if __name__ == "__main__": - feat = torch.randn((1, 1024)).cuda() + input = torch.randn((1, 32, 13, 13)).cuda() # 初始化模型 model = GinkaModel().cuda() @@ -37,14 +39,13 @@ if __name__ == "__main__": print_memory("初始化后") # 前向传播 - output = model(feat) + output = model(input, 1) print_memory("前向传播后") - print(f"输入形状: feat={feat.shape}") + print(f"输入形状: feat={input.shape}") print(f"输出形状: output={output.shape}") - # print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") - # print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}") + print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}") print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}") print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/model/output.py b/ginka/model/output.py index 11f77ed..d93931c 100644 --- a/ginka/model/output.py +++ b/ginka/model/output.py @@ -1,13 +1,42 @@ import torch import torch.nn as nn -class GinkaOutput(nn.Module): - def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)): +class StageHead(nn.Module): + def __init__(self, in_ch, out_ch, out_size=(13, 13)): super().__init__() - self.conv_down = nn.Sequential( + self.head = nn.Sequential( + nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(in_ch), + nn.ELU(), + + nn.Conv2d(in_ch, in_ch, 1), + nn.InstanceNorm2d(in_ch), + nn.ELU(), + ) + self.pool = nn.Sequential( nn.AdaptiveMaxPool2d(out_size), nn.Conv2d(in_ch, out_ch, 1) ) def forward(self, x): - return self.conv_down(x) + x = self.head(x) + x = self.pool(x) + return x + +class GinkaOutput(nn.Module): + def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)): + super().__init__() + self.head1 = StageHead(in_ch, out_ch, out_size) + self.head2 = StageHead(in_ch, out_ch, out_size) + self.head3 = StageHead(in_ch, out_ch, out_size) + + def forward(self, x, stage): + if stage == 1: + x = self.head1(x) + elif stage == 2: + x = self.head2(x) + elif stage == 3: + x = self.head3(x) + else: + raise RuntimeError("Unknown generate stage.") + return x diff --git a/ginka/model/unet.py b/ginka/model/unet.py index dd1d185..d72211b 100644 --- a/ginka/model/unet.py +++ b/ginka/model/unet.py @@ -198,15 +198,15 @@ class GinkaBottleneck(nn.Module): return x class GinkaUNet(nn.Module): - def __init__(self, base_ch=64, out_ch=32, feat_dim=1024): + def __init__(self, in_ch=32, base_ch=64, out_ch=32): """Ginka Model UNet 部分 """ super().__init__() - self.input = GinkaTransformerEncoder( - in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size - token_size=4, ff_dim=feat_dim*2, num_layers=4 - ) - self.down1 = ConvBlock(2, base_ch) + # self.input = GinkaTransformerEncoder( + # in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size + # token_size=4, ff_dim=feat_dim*2, num_layers=4 + # ) + self.down1 = ConvBlock(in_ch, base_ch) self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16) self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8) self.down4 = GinkaEncoder(base_ch*4, base_ch*8) @@ -223,10 +223,6 @@ class GinkaUNet(nn.Module): ) def forward(self, x): - B, D = x.shape # [B, 1024] - x = x.view(B, 4, D // 4) # [B, 4, 256] - x = self.input(x) # [B, 4, 512] - x = x.view(B, 2, 32, 32) # [B, 2, 32, 32] x1 = self.down1(x) # [B, 64, 32, 32] x2 = self.down2(x1) # [B, 128, 16, 16] x3 = self.down3(x2) # [B, 256, 8, 8] @@ -237,5 +233,6 @@ class GinkaUNet(nn.Module): x = self.up1(x4, x3) # [B, 256, 8, 8] x = self.up2(x, x2) # [B, 128, 16, 16] x = self.up3(x, x1) # [B, 64, 32, 32] + x = self.final(x) # [B, 32, 32, 32] - return self.final(x) # [B, 32, 32, 32] + return x diff --git a/ginka/train.py b/ginka/train.py deleted file mode 100644 index 84eca3f..0000000 --- a/ginka/train.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -from datetime import datetime -import torch -import torch.optim as optim -from torch.utils.data import DataLoader -from tqdm import tqdm -from .model.model import GinkaModel -from .model.loss import GinkaLoss -from .dataset import GinkaDataset -from minamo.model.model import MinamoModel -from shared.args import parse_arguments - -BATCH_SIZE = 32 - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -os.makedirs("result", exist_ok=True) -os.makedirs("result/ginka_checkpoint", exist_ok=True) - -def train(): - print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") - - args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json') - - model = GinkaModel() - model.to(device) - minamo = MinamoModel(32) - minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) - minamo.to(device) - minamo.eval() - - # 准备数据集 - dataset = GinkaDataset(args.train, device, minamo) - dataset_val = GinkaDataset(args.validate, device, minamo) - dataloader = DataLoader( - dataset, - batch_size=BATCH_SIZE, - shuffle=True - ) - dataloader_val = DataLoader( - dataset_val, - batch_size=BATCH_SIZE, - shuffle=True - ) - - # 设定优化器与调度器 - optimizer = optim.AdamW(model.parameters(), lr=1e-3) - scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) - criterion = GinkaLoss(minamo) - - if args.resume: - data = torch.load(args.from_state, map_location=device) - model.load_state_dict(data["model_state"], strict=False) - if args.load_optim: - optimizer.load_state_dict(data["optimizer_state"]) - print("Train from loaded state.") - - else: - # 从头开始训练的话,初始时先把 minamo 损失值权重改为 0 - criterion.weight[0] = 0.0 - - # 开始训练 - for epoch in tqdm(range(args.epochs)): - model.train() - total_loss = 0 - - # 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来 - if not args.resume and epoch == 10: - criterion.weight[0] = 0.5 - - for batch in dataloader: - # 数据迁移到设备 - target = batch["target"].to(device) - target_vision_feat = batch["target_vision_feat"].to(device) - target_topo_feat = batch["target_topo_feat"].to(device) - feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) - # 前向传播 - optimizer.zero_grad() - _, output_softmax = model(feat_vec) - - # 计算损失 - losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat) - - # 反向传播 - losses.backward() - optimizer.step() - total_loss += losses.item() - - avg_loss = total_loss / len(dataloader) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") - - # for name, param in model.named_parameters(): - # if param.grad is not None: - # print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}") - - # 学习率调整 - scheduler.step() - - if (epoch + 1) % 5 == 0: - loss_val = 0 - model.eval() - with torch.no_grad(): - for batch in dataloader_val: - # 数据迁移到设备 - target = batch["target"].to(device) - target_vision_feat = batch["target_vision_feat"].to(device) - target_topo_feat = batch["target_topo_feat"].to(device) - feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) - - # 前向传播 - output, output_softmax = model(feat_vec) - print(torch.argmax(output, dim=1)[0]) - - # 计算损失 - losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat) - loss_val += losses.item() - - avg_val_loss = loss_val / len(dataloader_val) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") - torch.save({ - "model_state": model.state_dict(), - # "optimizer_state": optimizer.state_dict(), - }, f"result/ginka_checkpoint/{epoch + 1}.pth") - - - print("Train ended.") - - torch.save({ - "model_state": model.state_dict(), - # "optimizer_state": optimizer.state_dict(), - }, f"result/ginka.pth") - -if __name__ == "__main__": - torch.set_num_threads(4) - train() diff --git a/ginka/train_gan.py b/ginka/train_gan.py deleted file mode 100644 index 51a3af8..0000000 --- a/ginka/train_gan.py +++ /dev/null @@ -1,410 +0,0 @@ -import argparse -import socket -import struct -import os -from datetime import datetime -import torch -import torch.optim as optim -import torch.nn.functional as F -from torch_geometric.loader import DataLoader -from tqdm import tqdm -import cv2 -import numpy as np -from .model.model import GinkaModel -from .model.loss import GinkaLoss, WGANGinkaLoss -from .dataset import GinkaDataset, MinamoGANDataset -from minamo.model.model import MinamoModel -from minamo.model.loss import MinamoLoss -from shared.image import matrix_to_image_cv - -BATCH_SIZE = 32 -EPOCHS_GINKA = 5 -EPOCHS_MINAMO = 2 -SOCKET_PATH = "./tmp/ginka_uds" -LOSS_PATH = "result/gan/a-loss.txt" -REPLAY_PATH = "datasets/replay.bin" -VISION_ALPHA = 0 - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -os.makedirs("result", exist_ok=True) -os.makedirs("result/ginka_checkpoint", exist_ok=True) -os.makedirs("result/gan", exist_ok=True) -os.makedirs("tmp", exist_ok=True) - -with open(LOSS_PATH, 'a', encoding='utf-8') as f: - f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n") - -if not os.path.exists(REPLAY_PATH): - with open(REPLAY_PATH, 'wb') as f: - f.write(b'\x00\x00\x00\x00') - -def parse_arguments(): - parser = argparse.ArgumentParser(description="training codes") - parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--from_state", type=str, default="result/ginka.pth") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default='ginka-eval.json') - parser.add_argument("--from_cycle", type=int, default=0) - parser.add_argument("--to_cycle", type=int, default=100) - args = parser.parse_args() - return args - -def parse_ginka_batch(batch): - target = batch["target"].to(device) - target_vision_feat = batch["target_vision_feat"].to(device).squeeze(1) - target_topo_feat = batch["target_topo_feat"].to(device).squeeze(1) - feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=1).to(device) - - return target, target_vision_feat, target_topo_feat, feat_vec - -def parse_minamo_batch(batch): - map1, map2, vision_simi, topo_simi, graph1, graph2 = batch - map1 = map1.to(device) # 转为 [B, C, H, W] - map2 = map2.to(device) - topo_simi = topo_simi.to(device) - vision_simi = vision_simi.to(device) - graph1 = graph1.to(device) - graph2 = graph2.to(device) - return map1, map2, vision_simi, topo_simi, graph1, graph2 - -def send_all(sock, data): - total_sent = 0 - while total_sent < len(data): - sent = sock.send(data[total_sent:]) - if sent == 0: - raise RuntimeError("Socket connection broken") - total_sent += sent - -def recv_all(sock: socket.socket, length: int): - """循环接收直到获得指定长度的数据""" - data = bytes() - while len(data) < length: - packet = sock.recv(length - len(data)) # 只请求剩余部分 - if not packet: - raise ConnectionError("连接中断") - data += packet - return data - -def parse_minamo_data(sock: socket.socket, maps: np.ndarray): - # 数据通讯 node 输出协议,单位字节: - # 2 - Tensor count; 2 - Review count. Review is right behind train data; - # 1*tc - Compare count for every map tensor delivered. - # 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo; - # N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor. - _, _, H, W = maps.shape - tc_buf = sock.recv(2) - rc_buf = sock.recv(2) - tc = struct.unpack('>h', tc_buf)[0] - rc = struct.unpack('>h', rc_buf)[0] - count_buf = recv_all(sock, 1 * tc) - count: list = struct.unpack(f">{tc}b", count_buf) - N = sum(count) - sim_buf = recv_all(sock, 2 * 4 * (N + rc)) - com_buf = recv_all(sock, N * 1 * H * W) - review_buf = recv_all(sock, rc * 2 * H * W) if rc > 0 else bytes() - - sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf) - com = struct.unpack(f">{N * 1 * H * W}b", com_buf) - review = struct.unpack(f">{rc * 2 * H * W}", review_buf) if rc > 0 else list() - - res = list() - flatten_idx = 0 - # 读取当前这一轮生成器的数据 - for idx in range(tc): - com_count = count[idx] - for i in range(com_count): - com_start = flatten_idx * H * W - com_end = (flatten_idx + 1) * H * W - vis_sim = sim[flatten_idx * 2] - topo_sim = sim[flatten_idx * 2 + 1] - com_data = com[com_start:com_end] - flatten_idx += 1 - com_map = np.array(com_data, dtype=np.int8).reshape(H, W) - # map1, map2, vision_similarity, topo_similarity, is_review - res.append((maps[idx], com_map, vis_sim, topo_sim, False)) - - return res - -def train(): - print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") - - args = parse_arguments() - - ginka = GinkaModel() - ginka.to(device) - minamo = MinamoModel(32) - minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) - minamo.to(device) - minamo.eval() - - # 准备数据集 - ginka_dataset = GinkaDataset(args.train, device, minamo) - ginka_dataset_val = GinkaDataset(args.validate, device, minamo) - minamo_dataset = MinamoGANDataset("datasets/minamo-dataset-1.json") - minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json") - ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True) - ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True) - minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE // 2, shuffle=True) - minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE // 2, shuffle=True) - - # 设定优化器与调度器 - optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) - scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6) - criterion_ginka = GinkaLoss(minamo) - - optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9)) - scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=EPOCHS_MINAMO, T_mult=2, eta_min=1e-6) - criterion_minamo = MinamoLoss() - - criterion = WGANGinkaLoss() - - # 用于生成图片 - tile_dict = dict() - for file in os.listdir('tiles'): - name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) - - # 与 node 端通讯 - if os.path.exists(SOCKET_PATH): - os.remove(SOCKET_PATH) - server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - server.bind(SOCKET_PATH) - server.listen(1) - - if args.resume: - data = torch.load(args.from_state, map_location=device) - ginka.load_state_dict(data["model_state"], strict=False) - print("Train from loaded state.") - - print("Waiting for client connection...") - conn, _ = server.accept() - print("Client connected.") - - for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"): - # -------------------- 训练生成器 - for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model", leave=False): - ginka.train() - minamo.eval() - total_loss = 0 - - for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"): - # 数据迁移到设备 - target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) - # 前向传播 - optimizer_ginka.zero_grad() - _, output_softmax = ginka(feat_vec) - # 计算损失 - losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat) - # 反向传播 - losses.backward() - optimizer_ginka.step() - total_loss += losses.item() - - avg_loss = total_loss / len(ginka_dataloader) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}") - - # 学习率调整 - scheduler_ginka.step(epoch + 1) - - if (epoch + 1) % 5 == 0: - loss_val = 0 - ginka.eval() - idx = 0 - with torch.no_grad(): - for batch in tqdm(ginka_dataloader_val, leave=False, desc="Validating Ginka Model"): - target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) - output, output_softmax = ginka(feat_vec) - losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat) - loss_val += losses.item() - if epoch + 1 == EPOCHS_GINKA: - # 最后一次验证的时候顺带生成图片 - map_matrix = torch.argmax(output, dim=1).cpu().numpy() - for matrix in map_matrix: - image = matrix_to_image_cv(matrix, tile_dict) - cv2.imwrite(f"result/ginka_img/{idx}.png", image) - idx += 1 - - avg_val_loss = loss_val / len(ginka_dataloader_val) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") - torch.save({ - "model_state": ginka.state_dict() - }, f"result/ginka_checkpoint/{epoch + 1}.pth") - - # 使用训练集生成 minamo 训练数据,更准确 - gen_list: np.ndarray = np.empty((0, 13, 13), np.int8) - prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32) - with torch.no_grad(): - for batch in ginka_dataloader: - target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) - output, output_softmax = ginka(feat_vec) - prob = output_softmax.cpu().numpy() - prob_list = np.concatenate((prob_list, prob), axis=0) - map_matrix = torch.argmax(output, dim=1).cpu().numpy() - gen_list = np.concatenate((gen_list, map_matrix), axis=0) - - tqdm.write(f"Cycle {cycle} Ginka train ended.") - torch.save({ - "model_state": ginka.state_dict() - }, f"result/gan/ginka-{cycle}.pth") - torch.save({ - "model_state": ginka.state_dict() - }, f"result/ginka.pth") - - # -------------------- 生成 Minamo 的训练数据 - - # 数据通讯 python 输出协议,单位字节: - # 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type. - N, H, W = gen_list.shape - gen_bytes = gen_list.astype(np.int8).tobytes() - buf = bytearray() - buf.extend(struct.pack('>h', N)) # Tensor count - buf.extend(struct.pack('>b', H)) # Map height - buf.extend(struct.pack('>b', W)) # Map width - buf.extend(gen_bytes) # Map tensor - conn.sendall(buf) - data = parse_minamo_data(conn, prob_list) - - vis_sim = 0 - topo_sim = 0 - for _, _, vis, topo, _ in data: - vis_sim += vis - topo_sim += topo - - vis_sim /= len(data) - topo_sim /= len(data) - - with open(LOSS_PATH, 'a', encoding='utf-8') as f: - f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {avg_val_loss:.12f}') - - # 经验回放部分 - with open(REPLAY_PATH, 'r+b') as f: - # 读取文件开头获取总长度 - f.seek(0) - count = struct.unpack('>i', f.read(4))[0] # 取出整数 - if count > 0: - replay = np.random.choice(count, size=min(count, len(data) // 4), replace=False) - - replay_data = np.empty((len(replay), 32, 13, 13)) - for i, n in enumerate(replay): - f.seek(n * 32 * 13 * 13 + 4) - arr = np.frombuffer(f.read(32 * 13 * 13 * 4), dtype=np.float32).reshape(32, 13, 13) - replay_data[i] = arr - - map_data: np.ndarray = replay_data.argmax(axis=1) - buf = bytearray() - buf.extend(struct.pack('>h', len(replay))) # Tensor count - buf.extend(struct.pack('>b', H)) # Map height - buf.extend(struct.pack('>b', W)) # Map width - buf.extend(map_data.astype(np.int8).tobytes()) # Map tensor - conn.sendall(buf) - data.extend(parse_minamo_data(conn, replay_data)) - - # 把新的内容写入文件末尾 - to_write = np.random.choice(N, size=min(N, 100), replace=False) - write_data = bytearray() - for n in to_write: - write_data.extend(prob_list[n].tobytes()) - - f.seek(0, 2) # 定位到文件末尾 - f.write(write_data) - - f.seek(0) # 定位到文件开头 - f.write(struct.pack('>i', count + len(to_write))) - f.flush() # 确保数据被刷新到磁盘 - - minamo_dataset.set_data(data) - - # -------------------- 训练判别器 - for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"): - ginka.eval() - minamo.train() - total_loss = 0 - - for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"): - map1, map2, vis_sim, topo_sim, graph1, graph2 = parse_minamo_batch(batch) - batch_size = map1.shape[0] - - if batch_size == 1: - continue - - # 前向传播 - optimizer_minamo.zero_grad() - vis_feat_real, topo_feat_real = minamo(map1, graph1) - vis_feat_ref, topo_feat_ref = minamo(map2, graph2) - - # 生成假数据 - with torch.no_grad(): - fake_feat = torch.randn((batch_size, 1024), device=device) - fake_data = ginka(fake_feat) - - # 创建插值样本 - alpha = torch.rand((batch_size, 1, 1, 1), device=device) - interpolates = (alpha * map2 + (1 - alpha) * fake_data).requires_grad_(True) - - vis_feat_fake, topo_feat_fake = minamo(fake_data) - vis_feat_interp, topo_feat_interp = minamo(interpolates) - - vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1) - topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1) - vis_pred_fake = F.cosine_similarity(vis_feat_fake, vis_feat_ref, dim=1).unsqueeze(-1) - topo_pred_fake = F.cosine_similarity(topo_feat_fake, topo_feat_ref, dim=1).unsqueeze(-1) - vis_pred_interp = F.cosine_similarity(vis_feat_interp, vis_feat_ref, dim=1).unsqueeze(-1) - topo_pred_interp = F.cosine_similarity(topo_feat_interp, topo_feat_ref, dim=1).unsqueeze(-1) - - # 计算相似度 - score_real = F.l1_loss(vis_pred_real, vis_sim) * VISION_ALPHA + F.l1_loss(topo_pred_real, topo_sim) * (1 - VISION_ALPHA) - score_fake = vis_pred_fake * VISION_ALPHA + topo_pred_fake * (1 - VISION_ALPHA) - score_interp = vis_pred_interp * VISION_ALPHA + topo_pred_interp * (1 - VISION_ALPHA) - - # 计算损失 - loss = criterion.discriminator_loss(score_real, score_fake, score_interp) - - # 反向传播 - loss.backward() - optimizer_minamo.step() - total_loss += loss.item() - - ave_loss = total_loss / len(minamo_dataloader) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_minamo.param_groups[0]['lr']):.6f}") - - scheduler_minamo.step(epoch + 1) - - # 每十轮推理一次验证集 - if epoch + 1 == EPOCHS_MINAMO: - minamo.eval() - val_loss = 0 - with torch.no_grad(): - for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"): - map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch) - - vis_feat_real, topo_feat_real = minamo(map1_val, graph1) - vis_feat_ref, topo_feat_ref = minamo(map2_val, graph2) - - vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1) - topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1) - - # 计算损失 - loss_val = criterion_minamo(vis_pred_real, topo_pred_real, vision_simi_val, topo_simi_val) - val_loss += loss_val.item() - - avg_val_loss = val_loss / len(minamo_dataloader_val) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") - torch.save({ - "model_state": minamo.state_dict() - }, f"result/minamo_checkpoint/{epoch + 1}.pth") - - tqdm.write(f"Cycle {cycle} Minamo train ended.") - torch.save({ - "model_state": minamo.state_dict() - }, f"result/gan/minamo-{cycle}.pth") - torch.save({ - "model_state": minamo.state_dict() - }, f"result/minamo.pth") - with open(LOSS_PATH, 'a', encoding='utf-8') as f: - f.write(f' | Minamo: {avg_val_loss:.12f}\n') - - print("Train ended.") - -if __name__ == "__main__": - torch.set_num_threads(4) - train() diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 1b5cbec..2410853 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -16,7 +16,7 @@ from shared.graph import batch_convert_soft_map_to_graph from shared.image import matrix_to_image_cv from shared.constant import VISION_WEIGHT, TOPO_WEIGHT -BATCH_SIZE = 32 +BATCH_SIZE = 16 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) @@ -30,39 +30,56 @@ def parse_arguments(): parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth") parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth") parser.add_argument("--train", type=str, default="ginka-dataset.json") + parser.add_argument("--validate", type=str, default="ginka-eval.json") parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--checkpoint", type=int, default=5) parser.add_argument("--load_optim", type=bool, default=True) args = parser.parse_args() return args -def clip_weights(model, clip_value=0.01): - for param in model.parameters(): - param.data = torch.clamp(param.data, -clip_value, clip_value) +def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + fake1: torch.Tensor = gen(masked1, 1) + fake2: torch.Tensor = gen(masked2, 2) + fake3: torch.Tensor = gen(masked3, 3) + if detach: + return fake1.detach(), fake2.detach(), fake3.detach() + else: + return fake1, fake2, fake3 + +def gen_total(gen, input, detach=False) -> torch.Tensor: + fake1 = gen(input, 1) + fake2 = gen(fake1, 2) + fake3 = gen(fake2, 3) + if detach: + return fake3.detach() + else: + return fake3 def train(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") args = parse_arguments() - # c_steps = 1 if args.resume else 5 - # g_steps = 5 if args.resume else 1 c_steps = 5 g_steps = 1 + # 1 代表课程学习阶段,2 代表课程学习后,逐渐转为联合学习的阶段 + # 3 代表课程学习后的联合遮挡学习阶段,4 代表最后随机输入的联合学习阶段 + train_stage = 1 + mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练 + random_ratio = 0 ginka = GinkaModel() minamo = MinamoScoreModule() - minamo_sim = MinamoSimilarityModel() ginka.to(device) minamo.to(device) - minamo_sim.to(device) dataset = GinkaWGANDataset(args.train, device) - dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) + dataset_val = GinkaWGANDataset(args.validate, device) + dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) + dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True) optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9)) - optimizer_minamo_sim = optim.Adam(minamo_sim.parameters(), lr=1e-4) # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs) # scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs) @@ -82,87 +99,134 @@ def train(): ginka.load_state_dict(data_ginka["model_state"], strict=False) minamo.load_state_dict(data_minamo["model_state"], strict=False) - if data_ginka["c_steps"] is not None and data_ginka["g_steps"] is not None: + if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None: c_steps = data_ginka["c_steps"] g_steps = data_ginka["g_steps"] + if data_ginka.get("mask_ratio") is not None: + mask_ratio = data_ginka["mask_ratio"] + + if data_ginka.get("random_ratio") is not None: + random_ratio = data_ginka["random_ratio"] + + if data_ginka.get("stage") is not None: + train_stage = data_ginka["stage"] + if args.load_optim: - if data_ginka["optim_state"] is not None: + if data_ginka.get("optim_state") is not None: optimizer_ginka.load_state_dict(data_ginka["optim_state"]) - if data_minamo["optim_state"] is not None: + if data_minamo.get("optim_state") is not None: optimizer_minamo.load_state_dict(data_minamo["optim_state"]) - if data_minamo["optim_state_sim"] is not None: - optimizer_minamo_sim.load_state_dict(data_minamo["optim_state_sim"]) + + dataset.train_stage = train_stage + dataset.mask_ratio1 = mask_ratio + dataset.mask_ratio2 = mask_ratio + dataset.mask_ratio3 = mask_ratio + dataset.random_ratio = random_ratio + + dataset_val.train_stage = train_stage + dataset_val.mask_ratio1 = mask_ratio + dataset_val.mask_ratio2 = mask_ratio + dataset_val.mask_ratio3 = mask_ratio + dataset_val.random_ratio = random_ratio print("Train from loaded state.") + low_loss_epochs = 0 + for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm): loss_total_minamo = torch.Tensor([0]).to(device) - loss_total_minamo_sim = torch.Tensor([0]).to(device) loss_total_ginka = torch.Tensor([0]).to(device) dis_total = torch.Tensor([0]).to(device) + loss_ce_total = torch.Tensor([0]).to(device) - for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - batch_size = real_data.size(0) - real_data = real_data.to(device) - real_graph = batch_convert_soft_map_to_graph(real_data) + for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): + real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch] # ---------- 训练判别器 for _ in range(c_steps): # 生成假样本 optimizer_minamo.zero_grad() - z = torch.rand(batch_size, 1024, device=device) - fake_data = ginka(z) - fake_data = fake_data.detach() - - # 计算判别器输出 - # 反向传播 - dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data) - loss_d.backward() - # torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=2.0) - # total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2) - # print("Critic 梯度范数:", total_norm.item()) - # print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item()) - # print("Critic 输出范围:", d_real.min().item(), d_real.max().item()) + optimizer_ginka.zero_grad() + if train_stage == 1 or train_stage == 2: + fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True) + + loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1) + loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2) + loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3) + + dis_avg = (dis1 + dis2 + dis3) / 3.0 + loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0 + + # 反向传播 + loss_d_avg.backward() + elif train_stage == 3: + pass + optimizer_minamo.step() - loss_total_minamo += loss_d.detach() - dis_total += dis.detach() + loss_total_minamo += loss_d_avg.detach() + dis_total += dis_avg.detach() # ---------- 训练生成器 for _ in range(g_steps): + optimizer_minamo.zero_grad() optimizer_ginka.zero_grad() - # optimizer_minamo_sim.zero_grad() - - z1 = torch.randn(batch_size, 1024, device=device) - z2 = torch.randn(batch_size, 1024, device=device) - fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2) - - # 先训练辅助判别器 - # loss_c_assist = criterion.discriminator_loss_assist2(minamo_sim, real_data, fake_softmax1, fake_softmax2) - # loss_c_assist.backward(retain_graph=True) - # optimizer_minamo_sim.step() - - loss_g = criterion.generator_loss(minamo, minamo_sim, fake_softmax1, fake_softmax2) - loss_g.backward() - optimizer_ginka.step() - - loss_total_ginka += loss_g - # loss_total_minamo_sim += loss_c_assist.detach() - # tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}") + if train_stage == 1 or train_stage == 2: + fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False) + + loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1) + loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2) + loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3) + + loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0 + loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3) + + loss_g.backward() + optimizer_ginka.step() + loss_total_ginka += loss_g.detach() + loss_ce_total += loss_ce.detach() + + elif train_stage == 3: + pass avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps - avg_loss_minamo_sim = loss_total_minamo_sim.item() / len(dataloader) / g_steps + avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps avg_dis = dis_total.item() / len(dataloader) / c_steps tqdm.write( - f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +\ - f"Epoch: {epoch + 1} | W Loss: {avg_dis:.8f} | " +\ - f"G Loss: {avg_loss_ginka:.8f} | D Loss: {avg_loss_minamo:.8f} | " +\ - f"lr G: {(optimizer_ginka.param_groups[0]['lr']):.8f}" + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + + f"Epoch: {epoch + 1} | W: {avg_dis:.8f} | " + + f"G: {avg_loss_ginka:.8f} | D: {avg_loss_minamo:.8f} | " + + f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}" ) + if avg_loss_ce < 0.5: + low_loss_epochs += 1 + else: + low_loss_epochs = 0 + + if low_loss_epochs >= 5 and train_stage == 2: + random_ratio += 0.1 + random_ratio = min(random_ratio, 0.5) + low_loss_epochs = 0 + + if low_loss_epochs >= 5 and train_stage == 1: + if mask_ratio >= 0.9: + train_stage = 2 + + mask_ratio += 0.1 + mask_ratio = min(mask_ratio, 0.9) + low_loss_epochs = 0 + + dataset.train_stage = 2 + dataset_val.train_stage = 2 + dataset.random_ratio = random_ratio + dataset_val.random_ratio = random_ratio + dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio + dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio + # scheduler_ginka.step() # scheduler_minamo.step() @@ -172,38 +236,44 @@ def train(): g_steps = 1 if avg_loss_ginka > 0 or avg_loss_minamo > 0: - c_steps = min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15) + c_steps = int(max(min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15), 1)) else: c_steps = 5 # 每若干轮输出一次图片,并保存检查点 if (epoch + 1) % args.checkpoint == 0: - # 输出 20 张图片,每批次 4 张,一共五批 - idx = 0 - with torch.no_grad(): - for _ in range(5): - z = torch.randn(4, 1024, device=device) - output = ginka(z) - - map_matrix = torch.argmax(output, dim=1).cpu().numpy() - for matrix in map_matrix: - image = matrix_to_image_cv(matrix, tile_dict) - cv2.imwrite(f"result/ginka_img/{idx}.png", image) - idx += 1 - # 保存检查点 torch.save({ "model_state": ginka.state_dict(), "optim_state": optimizer_ginka.state_dict(), "c_steps": c_steps, - "g_steps": g_steps + "g_steps": g_steps, + "stage": train_stage, + "mask_ratio": mask_ratio, + "random_ratio": random_ratio, }, f"result/wgan/ginka-{epoch + 1}.pth") torch.save({ "model_state": minamo.state_dict(), - "model_state_sim": minamo_sim.state_dict(), - "optim_state": optimizer_minamo.state_dict(), - "optim_state_sim": optimizer_minamo_sim.state_dict() + "optim_state": optimizer_minamo.state_dict() }, f"result/wgan/minamo-{epoch + 1}.pth") + + idx = 0 + with torch.no_grad(): + for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): + real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch] + if train_stage == 1: + fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True) + fake1 = torch.argmax(fake1, dim=1).cpu().numpy() + fake2 = torch.argmax(fake2, dim=1).cpu().numpy() + fake3 = torch.argmax(fake3, dim=1).cpu().numpy() + + for i in range(fake1.shape[0]): + for key, one in enumerate([fake1, fake2, fake3]): + map_matrix = one[i] + image = matrix_to_image_cv(map_matrix, tile_dict) + cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image) + + idx += 1 print("Train ended.") torch.save({ @@ -211,7 +281,6 @@ def train(): }, f"result/ginka.pth") torch.save({ "model_state": minamo.state_dict(), - "model_state_sim": minamo_sim.state_dict(), }, f"result/minamo.pth") if __name__ == "__main__": diff --git a/minamo/model/model.py b/minamo/model/model.py index 5d245e0..209f59c 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -20,23 +20,41 @@ class MinamoModel(nn.Module): return vision_feat, topo_feat +class MinamoScoreHead(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.vision_fc = nn.Sequential( + spectral_norm(nn.Linear(in_dim, out_dim)), + ) + self.topo_fc = nn.Sequential( + spectral_norm(nn.Linear(in_dim, out_dim)) + ) + + def forward(self, vis_feat, topo_feat): + vis_score = self.vision_fc(vis_feat) + topo_score = self.topo_fc(topo_feat) + return vis_score, topo_score + class MinamoScoreModule(nn.Module): def __init__(self, tile_types=32): super().__init__() self.topo_model = MinamoTopoModel(tile_types) self.vision_model = MinamoVisionModel(tile_types) # 输出层 - self.topo_fc = nn.Sequential( - spectral_norm(nn.Linear(512, 1)), - ) - self.vision_fc = nn.Sequential( - spectral_norm(nn.Linear(512, 1)), - ) + self.head1 = MinamoScoreHead(512, 1) + self.head2 = MinamoScoreHead(512, 1) + self.head3 = MinamoScoreHead(512, 1) - def forward(self, map, graph): - topo_feat = self.topo_model(graph) - topo_score = self.topo_fc(topo_feat) + def forward(self, map, graph, stage): vision_feat = self.vision_model(map) - vision_score = self.vision_fc(vision_feat) + topo_feat = self.topo_model(graph) + if stage == 1: + vision_score, topo_score = self.head1(vision_feat, topo_feat) + elif stage == 2: + vision_score, topo_score = self.head2(vision_feat, topo_feat) + elif stage == 3: + vision_score, topo_score = self.head3(vision_feat, topo_feat) + else: + raise RuntimeError("Unknown critic stage.") score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score return score, vision_score, topo_score diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..de59a05 --- /dev/null +++ b/train.sh @@ -0,0 +1,4 @@ +# 从头训练 +python3 -u -m ginka.train_wgan >> output.log +# 接续训练 +python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log \ No newline at end of file