diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index 517b429..45739ba 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -28,13 +28,13 @@ class GinkaMaskGIT(nn.Module): self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) # 剩余密度投影:将 5 个浮点数 [wall, door, monster, entrance, resource] 投影为 d_z 维 token - self.remain_proj = nn.Linear(5, d_z) + self.remain_proj = nn.Linear(1, d_z) # z 投影:逐 token 线性变换,保持序列结构 self.z_proj = nn.Linear(d_z, d_z) # 条件融合投影:z_seq_len 个 z token + 2 个结构 token + 1 个剩余密度 token - self.cond_proj = nn.Linear((z_seq_len + 3) * d_z, d_model) + self.cond_proj = nn.Linear((z_seq_len + 2 + 5) * d_z, d_model) # 纯 encoder Transformer,条件向量 c 通过 AdaLN 注入每一层 self.transformer = Transformer( @@ -62,7 +62,7 @@ class GinkaMaskGIT(nn.Module): ], dim=1) # 剩余密度:连续浮点向量投影为单个 d_z 维 token,[B, 1, d_z] - e_remain = self.remain_proj(remain).unsqueeze(1) + e_remain = self.remain_proj(remain.unsqueeze(-1)) # z:逐 token 投影,保留序列结构 [B, z_seq_len, d_z] z_proj = self.z_proj(z) diff --git a/ginka/train_seperated.py b/ginka/train_seperated.py index a489cb6..c712938 100644 --- a/ginka/train_seperated.py +++ b/ginka/train_seperated.py @@ -23,8 +23,8 @@ from shared.image import matrix_to_image_cv # # 整体架构: # VQ-VAE(三组独立编码器 vq1/vq2/vq3)将三阶段地图上下文分别编码为离散潜变量, -# 再由共用 VectorQuantizer 统一量化为 z_q; -# 三个独立 MaskGIT(mg1/mg2/mg3)分别以 z_q 和 struct_inject 为条件, +# 再由三个独立 VectorQuantizer 分别量化为 z_q1/z_q2/z_q3; +# 三个独立 MaskGIT(mg1/mg2/mg3)分别以各自阶段 z_q 和 struct_inject 为条件, # 逐阶段迭代解码地图图块序列。 # # 三阶段生成目标: @@ -37,14 +37,14 @@ from shared.image import matrix_to_image_cv # 共用 VQ-VAE 超参 # 三组编码器(vq1/vq2/vq3)共享相同超参,分别对三阶段地图上下文独立编码 -VQ_L = 2 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3) -VQ_K = 8 # codebook 大小(离散码本条目数) +VQ_L = 16 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3) +VQ_K = 16 # codebook 大小(离散码本条目数) VQ_D_Z = 64 # 码字维度 -VQ_BETA = 0.5 # commit loss 权重(防止编码器输出漂离 codebook) +VQ_BETA = 1.0 # commit loss 权重(防止编码器输出漂离 codebook) VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用) -VQ_LAYERS = 3 # VQ-VAE Transformer 层数 -VQ_DIM_FF = 512 # VQ-VAE 前馈网络隐层维度 -VQ_D_MODEL = 128 # VQ-VAE Transformer 模型维度 +VQ_LAYERS = 6 # VQ-VAE Transformer 层数 +VQ_DIM_FF = 1024 # VQ-VAE 前馈网络隐层维度 +VQ_D_MODEL = 256 # VQ-VAE Transformer 模型维度 VQ_NHEAD = 4 # VQ-VAE 多头注意力头数 # 第一阶段 MaskGIT 超参 @@ -65,10 +65,10 @@ STAGE3_MG_NHEAD = 4 STAGE3_MG_NUM_LAYERS = 6 STAGE3_MG_DIM_FF = 1024 -# 三阶段 Focal Loss 损失权重(可调节各阶段对总损失的贡献比例) -STAGE1_FOCAL_WEIGHT = 1.0 -STAGE2_FOCAL_WEIGHT = 1.0 -STAGE3_FOCAL_WEIGHT = 1.0 +# 三阶段 Cross Entropy 损失权重(可调节各阶段对总损失的贡献比例) +STAGE1_CE_WEIGHT = 1.0 +STAGE2_CE_WEIGHT = 1.0 +STAGE3_CE_WEIGHT = 1.0 # 各阶段 VQ commit loss 权重(当前未单独使用,统一由 VQ_BETA 控制) STAGE1_VQ_WEIGHT = 0.5 @@ -95,7 +95,6 @@ MG_Z_DROPOUT = 0.1 # z 隐变量 Dropout 概率 MG_STRUCT_DROPOUT = 0.1 # 结构参量 Dropout 概率 # 损失参数 -FOCAL_GAMMA = 2.0 # Focal Loss 参数 VQ_BETA = 0.5 # 承诺损失权重 # 训练超参 @@ -105,6 +104,7 @@ MIN_LR = 1e-6 # 余弦退火最低学习率 WEIGHT_DECAY = 1e-4 # L2 正则化系数 EPOCHS = 400 # 总训练轮数 CHECKPOINT = 20 # 每隔多少 epoch 保存检查点并执行验证 +REFERENCE_SAMPLE_PROB = 0.2 # 训练时将参考掩码图无梯度自采样 1-3 步的概率 device = torch.device( "cuda:0" if torch.cuda.is_available() @@ -131,7 +131,7 @@ def parse_arguments(): def build_model(device: torch.device): # 三组 VQ-VAE 编码器:各自独立编码一个阶段的地图上下文(encoder_stage1/2/3) - # 输出形状均为 [B, L, d_z],拼接后送入共用 quantizer + # 输出形状均为 [B, L, d_z],分别送入各自阶段的 quantizer vq_kwargs = dict( num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL, nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_h=MAP_H, map_w=MAP_W @@ -140,21 +140,21 @@ def build_model(device: torch.device): vq2 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage2 上下文(door/monster/entrance) vq3 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage3 上下文(resource) - # 三个独立 MaskGIT 解码器,均接收完整的三阶段 z_q 作为条件 + # 三个独立 MaskGIT 解码器,分别接收各自阶段的 z_q 作为条件 mg1 = GinkaMaskGIT( num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF, nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W, - z_seq_len=VQ_L * 3 + z_seq_len=VQ_L ).to(device) mg2 = GinkaMaskGIT( num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF, nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W, - z_seq_len=VQ_L * 3 + z_seq_len=VQ_L ).to(device) mg3 = GinkaMaskGIT( num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF, nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W, - z_seq_len=VQ_L * 3 + z_seq_len=VQ_L ).to(device) # 六个模型参数合并到同一优化器,端到端联合训练 @@ -166,18 +166,106 @@ def build_model(device: torch.device): # 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数 scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR) - # 共用 VectorQuantizer:不参与梯度更新,仅在前向时做码本查表 - quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device) + # 三个独立 VectorQuantizer:各阶段使用自己的码本,避免语义空间相互干扰 + quantizer1 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device) + quantizer2 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device) + quantizer3 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device) + quantizers = (quantizer1, quantizer2, quantizer3) - return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler + return vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler -def focal_loss(logits, target): +def cross_entropy_loss(logits, target): # logits: [B, L, C],需转为 [B, C, L] 以匹配 cross_entropy 期望格式 - ce = F.cross_entropy(logits.permute(0, 2, 1), target, reduction='none') - pt = torch.exp(-ce) # pt = 模型对正确类的预测概率 - # Focal Loss:对高置信度样本降低权重,让模型更专注于难样本 - focal = ((1 - pt) ** FOCAL_GAMMA) * ce - return focal.mean() + return F.cross_entropy(logits.permute(0, 2, 1), target) + +def summarize_codebook_hits(code_hits: torch.Tensor) -> tuple[float, float, int]: + total_hits = code_hits.sum() + if total_hits.item() <= 0: + return 0.0, 0.0, 0 + + probs = code_hits / total_hits + perplexity = torch.exp( + -(probs * torch.log(probs.clamp_min(1e-10))).sum() + ).item() + active_codes = int((code_hits > 0).sum().item()) + usage_rate = active_codes / code_hits.numel() + return perplexity, usage_rate, active_codes + +def quantize_stage_latents( + quantizers: tuple[VectorQuantizer, VectorQuantizer, VectorQuantizer], + z_e1: torch.Tensor, + z_e2: torch.Tensor, + z_e3: torch.Tensor +) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + quantizer1, quantizer2, quantizer3 = quantizers + z_q1, _, commit_loss1, _, code_hits1 = quantizer1(z_e1) + z_q2, _, commit_loss2, _, code_hits2 = quantizer2(z_e2) + z_q3, _, commit_loss3, _, code_hits3 = quantizer3(z_e3) + + commit_loss = (commit_loss1 + commit_loss2 + commit_loss3) / 3 + code_hits = torch.stack([code_hits1, code_hits2, code_hits3], dim=0) + return (z_q1, z_q2, z_q3), commit_loss, code_hits + +def build_reference_rollout_steps(prob: float) -> int: + if random.random() >= prob: + return 0 + + return random.randint(1, 3) + +def sample_reference_inputs( + model: torch.nn.Module, + reference: torch.Tensor, + z_q: torch.Tensor, + struct: torch.Tensor, + target_density: torch.Tensor, + stage: int, + rollout_steps: int +) -> torch.Tensor: + if rollout_steps <= 0: + return reference + + sampled_reference = reference.clone() + with torch.no_grad(): + current = sampled_reference.clone() + z_q_detached = z_q.detach() + + for _ in range(rollout_steps): + masked_positions = current == MASK_TOKEN + masked_counts = masked_positions.sum(dim=1) + if int(masked_counts.sum().item()) <= 0: + break + + remain = compute_remaining(current, target_density, stage) + logits = model(current, z_q_detached, struct, remain) + probs = F.softmax(logits, dim=-1) + dist = torch.distributions.Categorical(probs) + predicted = dist.sample() + confidence = torch.gather( + probs, + -1, + predicted.unsqueeze(-1) + ).squeeze(-1) + + for local_idx in range(current.size(0)): + masked_count = int(masked_counts[local_idx].item()) + if masked_count <= 0: + continue + + masked_indices = masked_positions[local_idx].nonzero(as_tuple=True)[0] + reveal_count = max(1, math.ceil(masked_count * 0.1)) + reveal_count = min(reveal_count, masked_indices.numel()) + masked_confidence = confidence[local_idx, masked_indices] + _, reveal_order = torch.topk( + masked_confidence, + k=reveal_count, + largest=True + ) + reveal_indices = masked_indices[reveal_order] + current[local_idx, reveal_indices] = predicted[local_idx, reveal_indices] + + sampled_reference = current + + return sampled_reference def random_struct(device: torch.device) -> torch.Tensor: # 随机采样一组结构参量,用于无条件自由生成 @@ -336,14 +424,17 @@ def full_generate_random_z( device: torch.device, keep_fixed: tuple[bool, bool, bool] = (True, True, True) ) -> tuple: - vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models + vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models + quantizer1, quantizer2, quantizer3 = quantizers with torch.no_grad(): - z = quantizer.sample(1, VQ_L, device) + z1 = quantizer1.sample(1, VQ_L, device) + z2 = quantizer2.sample(1, VQ_L, device) + z3 = quantizer3.sample(1, VQ_L, device) # stage1:生成墙壁骨架 pred1_np = maskgit_sample( - mg1, input.clone(), z, struct, target_density, 1, + mg1, input.clone(), z1, struct, target_density, 1, GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0] ) inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) @@ -351,7 +442,7 @@ def full_generate_random_z( # stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并 pred2_np = maskgit_sample( - mg2, inp2, z, struct, target_density, 2, + mg2, inp2, z2, struct, target_density, 2, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1] ) merged12 = pred1_np.copy() @@ -361,7 +452,7 @@ def full_generate_random_z( # stage3:填充 resource(3) pred3_np = maskgit_sample( - mg3, inp3, z, struct, target_density, 3, + mg3, inp3, z3, struct, target_density, 3, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2] ) merged123 = merged12.copy() @@ -371,26 +462,28 @@ def full_generate_random_z( def full_generate_specific_z( input: torch.Tensor, - z: torch.Tensor, + z_q: tuple[torch.Tensor, torch.Tensor, torch.Tensor], struct: torch.Tensor, target_density: torch.Tensor, models: list[torch.nn.Module], device: torch.device, keep_fixed: tuple[bool, bool, bool] = (True, True, True) ) -> tuple: - vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models + vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models with torch.no_grad(): + z1, z2, z3 = z_q + # 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z pred1_np = maskgit_sample( - mg1, input.clone(), z, struct, target_density, 1, + mg1, input.clone(), z1, struct, target_density, 1, GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0] ) inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) inp2[inp2 == 0] = MASK_TOKEN pred2_np = maskgit_sample( - mg2, inp2, z, struct, target_density, 2, + mg2, inp2, z2, struct, target_density, 2, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1] ) merged12 = pred1_np.copy() @@ -399,7 +492,7 @@ def full_generate_specific_z( inp3[inp3 == 0] = MASK_TOKEN pred3_np = maskgit_sample( - mg3, inp3, z, struct, target_density, 3, + mg3, inp3, z3, struct, target_density, 3, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2] ) merged123 = merged12.copy() @@ -493,9 +586,10 @@ def visualize_part2(batch, z_q, models, device, tile_dict): inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE) struct_t = batch["struct_inject"][0:1].to(device) target_density_t = batch["target_density"][0:1].to(device) + z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1]) kf = rand_keep() auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z( - inp1_t, z_q[0:1], struct_t, target_density_t, models, device, keep_fixed=kf + inp1_t, z_q_single, struct_t, target_density_t, models, device, keep_fixed=kf ) kf_label = 'fix' if kf[0] else 'free' @@ -665,6 +759,7 @@ def visualize_density_var(batch, z_q, models, device, tile_dict): struct_t = batch["struct_inject"][0:1].to(device) struct_cpu = batch["struct_inject"][0] base_target_density = batch["target_density"][0:1].to(device) + z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1]) ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W) gen_imgs = [] wall_count_values = [20, 35, 50, 65, 80] @@ -673,7 +768,7 @@ def visualize_density_var(batch, z_q, models, device, tile_dict): fixed_target_density[0, WALL_DENSITY_IDX] = wall_count / MAP_SIZE target_density_cpu = fixed_target_density[0].cpu() _, _, merged123 = full_generate_specific_z( - inp1_t, z_q[0:1], struct_t, fixed_target_density, models, device + inp1_t, z_q_single, struct_t, fixed_target_density, models, device ) gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu)) row1 = [to_img(ref_np)] + gen_imgs[:2] @@ -695,7 +790,8 @@ def validate( density_stats: dict, epoch: int ): - vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models + vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models + quantizer1, quantizer2, quantizer3 = quantizers # 切换为推理模式(关闭 Dropout / BatchNorm 统计更新) for m in [vq1, vq2, vq3, mg1, mg2, mg3]: @@ -706,6 +802,7 @@ def validate( loss2_total = torch.Tensor([0]).to(device) loss3_total = torch.Tensor([0]).to(device) commit_total = torch.Tensor([0]).to(device) + code_hits_total = torch.zeros(3, quantizer1.K, device=device) density_metrics = { 1: {"mae": 0.0, "over": 0.0, "count": 0}, @@ -736,27 +833,30 @@ def validate( struct = batch["struct_inject"].to(device) target_density = batch["target_density"].to(device) - # VQ 编码:各阶段独立编码后拼接、量化 + # VQ 编码:各阶段独立编码并分别量化 z_e1 = vq1(enc1) # [B, L, d_z] z_e2 = vq2(enc2) z_e3 = vq3(enc3) - z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z] - z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z] + z_q, commit_loss, code_hits = quantize_stage_latents( + quantizers, z_e1, z_e2, z_e3 + ) + z_q1, z_q2, z_q3 = z_q remain1 = compute_remaining(inp1, target_density, 1) remain2 = compute_remaining(inp2, target_density, 2) remain3 = compute_remaining(inp3, target_density, 3) - # 三阶段 MaskGIT 推理(均以完整 z_q、struct 和动态 remain 为条件) - logits1 = mg1(inp1, z_q, struct, remain1) - logits2 = mg2(inp2, z_q, struct, remain2) - logits3 = mg3(inp3, z_q, struct, remain3) + # 三阶段 MaskGIT 推理:各阶段只接收自己的 z_q + logits1 = mg1(inp1, z_q1, struct, remain1) + logits2 = mg2(inp2, z_q2, struct, remain2) + logits3 = mg3(inp3, z_q3, struct, remain3) - loss1_total += focal_loss(logits1, target1) - loss2_total += focal_loss(logits2, target2) - loss3_total += focal_loss(logits3, target3) + loss1_total += cross_entropy_loss(logits1, target1) + loss2_total += cross_entropy_loss(logits2, target2) + loss3_total += cross_entropy_loss(logits3, target3) commit_total += commit_loss + code_hits_total += code_hits # 计算各目标对象的真实密度误差与过量生成密度 pred1_map = torch.argmax(logits1, dim=-1).cpu() @@ -806,19 +906,20 @@ def validate( for m in [vq1, vq2, vq3, mg1, mg2, mg3]: m.train() - return loss1_total, loss2_total, loss3_total, commit_total + return loss1_total, loss2_total, loss3_total, commit_total, code_hits_total def train(device: torch.device): args = parse_arguments() models = build_model(device) - vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models + vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models + quantizer1, quantizer2, quantizer3 = quantizers tqdm.write(f"Device: {device}") model_list = [ ("vq1", vq1), ("vq2", vq2), ("vq3", vq3), ("mg1", mg1), ("mg2", mg2), ("mg3", mg3), - ("quantizer", quantizer) + ("quantizer1", quantizer1), ("quantizer2", quantizer2), ("quantizer3", quantizer3) ] total_params = 0 for name, m in model_list: @@ -838,7 +939,15 @@ def train(device: torch.device): mg1.load_state_dict(ckpt["mg1"]) mg2.load_state_dict(ckpt["mg2"]) mg3.load_state_dict(ckpt["mg3"]) - quantizer.load_state_dict(ckpt["quantizer"]) + if "quantizer1" in ckpt: + quantizer1.load_state_dict(ckpt["quantizer1"]) + quantizer2.load_state_dict(ckpt["quantizer2"]) + quantizer3.load_state_dict(ckpt["quantizer3"]) + elif "quantizer" in ckpt: + quantizer1.load_state_dict(ckpt["quantizer"]) + quantizer2.load_state_dict(ckpt["quantizer"]) + quantizer3.load_state_dict(ckpt["quantizer"]) + tqdm.write("Loaded legacy shared quantizer weights into quantizer1/2/3") # load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练) if args.load_optim and "optimizer" in ckpt: optimizer.load_state_dict(ckpt["optimizer"]) @@ -878,6 +987,7 @@ def train(device: torch.device): loss2_total = torch.Tensor([0]).to(device) loss3_total = torch.Tensor([0]).to(device) commit_total = torch.Tensor([0]).to(device) + code_hits_total = torch.zeros(3, quantizer1.K, device=device) for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): # 三阶段各自的掩码输入序列、预测目标和编码器上下文 @@ -904,26 +1014,39 @@ def train(device: torch.device): z_e2 = vq2(enc2) z_e3 = vq3(enc3) - # 合并三阶段编码后量化 - z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z] - z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z] + # 三阶段分别量化,各自使用独立 codebook + z_q, commit_loss, code_hits = quantize_stage_latents( + quantizers, z_e1, z_e2, z_e3 + ) + z_q1, z_q2, z_q3 = z_q + + rollout_steps = build_reference_rollout_steps(REFERENCE_SAMPLE_PROB) + inp1 = sample_reference_inputs( + mg1, inp1, z_q1, struct, target_density, 1, rollout_steps + ) + inp2 = sample_reference_inputs( + mg2, inp2, z_q2, struct, target_density, 2, rollout_steps + ) + inp3 = sample_reference_inputs( + mg3, inp3, z_q3, struct, target_density, 3, rollout_steps + ) remain1 = compute_remaining(inp1, target_density, 1) remain2 = compute_remaining(inp2, target_density, 2) remain3 = compute_remaining(inp3, target_density, 3) - # 三阶段 MaskGIT 前向(均接收完整三阶段 z_q、struct 和动态 remain 条件) - logits1 = mg1(inp1, z_q, struct, remain1) - logits2 = mg2(inp2, z_q, struct, remain2) - logits3 = mg3(inp3, z_q, struct, remain3) + # 三阶段 MaskGIT 前向:各阶段只接收自己的 z_q、struct 和动态 remain 条件 + logits1 = mg1(inp1, z_q1, struct, remain1) + logits2 = mg2(inp2, z_q2, struct, remain2) + logits3 = mg3(inp3, z_q3, struct, remain3) - # 三阶段 Focal Loss + VQ commit loss 加权求和 - loss1 = focal_loss(logits1, target1) - loss2 = focal_loss(logits2, target2) - loss3 = focal_loss(logits3, target3) - loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1 - loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2 - loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3 + # 三阶段 Cross Entropy + VQ commit loss 加权求和 + loss1 = cross_entropy_loss(logits1, target1) + loss2 = cross_entropy_loss(logits2, target2) + loss3 = cross_entropy_loss(logits3, target3) + loss1_weighted = STAGE1_CE_WEIGHT * loss1 + loss2_weighted = STAGE2_CE_WEIGHT * loss2 + loss3_weighted = STAGE3_CE_WEIGHT * loss3 commit_weighted = VQ_BETA * commit_loss loss = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted @@ -936,11 +1059,13 @@ def train(device: torch.device): loss2_total += loss2.detach() loss3_total += loss3.detach() commit_total += commit_loss.detach() + code_hits_total += code_hits.detach() # 每个 epoch 结束后更新学习率 scheduler.step() data_length = len(dataloader) + train_perplexity, train_usage_rate, train_active_codes = summarize_codebook_hits(code_hits_total) tqdm.write( f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | " @@ -948,20 +1073,23 @@ def train(device: torch.device): f"L2: {loss2_total.item() / data_length:.6f} | " f"L3: {loss3_total.item() / data_length:.6f} | " f"VQ: {commit_total.item() / data_length:.6f} | " + f"PPL: {train_perplexity:.4f} | " + f"Usage: {train_usage_rate:.4f} ({train_active_codes}/{code_hits_total.numel()}) | " f"LR: {scheduler.get_last_lr()[0]:.6f}" ) # 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存 if (epoch + 1) % CHECKPOINT == 0: losses = validate(dataloader_val, models, device, tile_dict, dataset.density_stats, epoch + 1) - loss1_total, loss2_total, loss3_total, commit_total = losses - loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1_total - loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2_total - loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3_total + loss1_total, loss2_total, loss3_total, commit_total, code_hits_total = losses + loss1_weighted = STAGE1_CE_WEIGHT * loss1_total + loss2_weighted = STAGE2_CE_WEIGHT * loss2_total + loss3_weighted = STAGE3_CE_WEIGHT * loss3_total commit_weighted = VQ_BETA * commit_total loss_total = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted data_length = len(dataloader_val) + val_perplexity, val_usage_rate, val_active_codes = summarize_codebook_hits(code_hits_total) tqdm.write( f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | " @@ -969,6 +1097,8 @@ def train(device: torch.device): f"L2: {loss2_total.item() / data_length:.6f} | " f"L3: {loss3_total.item() / data_length:.6f} | " f"VQ: {commit_total.item() / data_length:.6f} | " + f"PPL: {val_perplexity:.4f} | " + f"Usage: {val_usage_rate:.4f} ({val_active_codes}/{code_hits_total.numel()}) | " ) ckpt_path = f"result/seperated/sep-{epoch + 1}.pth" @@ -980,7 +1110,9 @@ def train(device: torch.device): "mg1": mg1.state_dict(), "mg2": mg2.state_dict(), "mg3": mg3.state_dict(), - "quantizer": quantizer.state_dict(), + "quantizer1": quantizer1.state_dict(), + "quantizer2": quantizer2.state_dict(), + "quantizer3": quantizer3.state_dict(), "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), }, ckpt_path) @@ -996,7 +1128,9 @@ def train(device: torch.device): "mg1": mg1.state_dict(), "mg2": mg2.state_dict(), "mg3": mg3.state_dict(), - "quantizer": quantizer.state_dict(), + "quantizer1": quantizer1.state_dict(), + "quantizer2": quantizer2.state_dict(), + "quantizer3": quantizer3.state_dict(), "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), }, final_path) diff --git a/ginka/vqvae/quantize.py b/ginka/vqvae/quantize.py index 7a8c0ab..b9dcaf7 100644 --- a/ginka/vqvae/quantize.py +++ b/ginka/vqvae/quantize.py @@ -18,8 +18,23 @@ class VectorQuantizer(nn.Module): self.codebook = nn.Embedding(K, d_z) nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K) - def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # z_e: [B, L * 3, d_z] + def codebook_stats( + self, indices: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat_indices = indices.reshape(-1) + one_hot = F.one_hot(flat_indices, num_classes=self.K).float() + avg_probs = one_hot.mean(dim=0) + perplexity = torch.exp( + -(avg_probs * torch.log(avg_probs.clamp_min(1e-10))).sum() + ) + usage_rate = (avg_probs > 0).float().mean() + usage_count = one_hot.sum(dim=0) + return perplexity, usage_rate, usage_count + + def forward( + self, z_e: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # z_e: [B, L, d_z] """ Args: z_e: [B, L, d_z] 编码器输出的连续向量序列 @@ -31,7 +46,7 @@ class VectorQuantizer(nn.Module): """ B, L, d_z = z_e.shape - z_flat = z_e.reshape(B * L, d_z) # [B * L * 3, d_z] + z_flat = z_e.reshape(B * L, d_z) # [B * L, d_z] codebook_w = self.codebook.weight # [K, d_z] @@ -56,13 +71,9 @@ class VectorQuantizer(nn.Module): commit_loss = F.mse_loss(z_e, z_q.detach()) indices = indices.reshape(B, L) - return z_q_st, indices, commit_loss + perplexity, usage_rate, usage_count = self.codebook_stats(indices) + return z_q_st, indices, commit_loss, perplexity, usage_count def sample(self, B: int, L: int, device: torch.device) -> torch.Tensor: - indices1 = torch.randint(0, self.K, (B, L), device=device) - indices2 = torch.randint(0, self.K, (B, L), device=device) - indices3 = torch.randint(0, self.K, (B, L), device=device) - z1 = self.codebook(indices1) - z2 = self.codebook(indices2) - z3 = self.codebook(indices3) - return torch.cat([z1, z2, z3], dim=1) + indices = torch.randint(0, self.K, (B, L), device=device) + return self.codebook(indices)