From 14ee52fb2fde56e03422d1ff01499cdf0086ec5f Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 20 May 2026 18:02:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BB=8E=E8=AE=AD=E7=BB=83=E9=9B=86?= =?UTF-8?q?=E9=9A=8F=E6=9C=BA=E6=8A=BD=E6=A0=B7=E7=94=9F=E6=88=90=EF=BC=8C?= =?UTF-8?q?=E8=80=8C=E4=B8=8D=E6=98=AF=E5=AE=8C=E5=85=A8=E9=9A=8F=E6=9C=BA?= =?UTF-8?q?=E9=87=87=E6=A0=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/dataset.py | 83 ++++++++---- ginka/train_seperated.py | 271 ++++++++++----------------------------- 2 files changed, 121 insertions(+), 233 deletions(-) diff --git a/ginka/dataset.py b/ginka/dataset.py index b616494..86695a9 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -71,6 +71,61 @@ class GinkaSeperatedDataset(Dataset): def __len__(self): return len(self.data) + def build_struct_inject(self, map_np: np.ndarray, outer_wall: int) -> torch.Tensor: + sym_h, sym_v, sym_c = compute_symmetry(map_np) + cond_sym = sym_h * 4 + sym_v * 2 + sym_c + return torch.LongTensor([cond_sym, outer_wall]) + + def build_target_density(self, map_data: list) -> torch.Tensor: + return torch.FloatTensor([ + self.count_tile(map_data, self.WALL) / self.MAP_SIZE, + self.count_tile(map_data, self.DOOR) / self.MAP_SIZE, + self.count_tile(map_data, self.MONSTER) / self.MAP_SIZE, + self.count_tile(map_data, self.ENTRANCE) / self.MAP_SIZE, + self.count_tile(map_data, self.RESOURCE) / self.MAP_SIZE + ]) + + def build_encoder_inputs(self, raw: np.ndarray) -> tuple: + target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw.copy()) + enc1 = target1.copy() + enc2 = inp2.copy() + enc3 = raw.copy() + return enc1, enc2, enc3 + + def pack_sample(self, item: dict, map_np: np.ndarray, out: tuple) -> dict: + return { + "input_stage1": torch.LongTensor(out[0]), + "target_stage1": torch.LongTensor(out[1]), + "encoder_stage1": torch.LongTensor(out[2]), + "input_stage2": torch.LongTensor(out[3]), + "target_stage2": torch.LongTensor(out[4]), + "encoder_stage2": torch.LongTensor(out[5]), + "input_stage3": torch.LongTensor(out[6]), + "target_stage3": torch.LongTensor(out[7]), + "encoder_stage3": torch.LongTensor(out[8]), + "struct_inject": self.build_struct_inject(map_np, item['outerWall']), + "target_density": self.build_target_density(item['map']) + } + + def random_sample_map(self, idx: int | None = None) -> dict: + if idx is None: + idx = random.randrange(len(self.data)) + + item = self.data[idx] + map_np = np.array(item['map'], dtype=np.int64) + + enc1, enc2, enc3 = self.build_encoder_inputs(map_np) + sample = { + "encoder_stage1": torch.LongTensor(enc1), + "encoder_stage2": torch.LongTensor(enc2), + "encoder_stage3": torch.LongTensor(enc3), + "struct_inject": self.build_struct_inject(map_np, item['outerWall']), + "target_density": self.build_target_density(item['map']), + "raw_map": torch.LongTensor(map_np) + } + sample['sample_idx'] = idx + return sample + def degrade_tile(self, m: np.ndarray, tiles: list) -> np.ndarray: # 将指定 tile ID 替换为 floor(0),原地修改 for t in tiles: @@ -179,30 +234,4 @@ class GinkaSeperatedDataset(Dataset): else: out = self.apply_subset3(map_np) - sym_h, sym_v, sym_c = compute_symmetry(map_np) - cond_sym = sym_h * 4 + sym_v * 2 + sym_c - cond_outer = item['outerWall'] - struct_inject = torch.LongTensor([cond_sym, cond_outer]) - - m = item['map'] - target_density = torch.FloatTensor([ - self.count_tile(m, self.WALL) / self.MAP_SIZE, - self.count_tile(m, self.DOOR) / self.MAP_SIZE, - self.count_tile(m, self.MONSTER) / self.MAP_SIZE, - self.count_tile(m, self.ENTRANCE) / self.MAP_SIZE, - self.count_tile(m, self.RESOURCE) / self.MAP_SIZE, - ]) - - return { - "input_stage1": torch.LongTensor(out[0]), - "target_stage1": torch.LongTensor(out[1]), - "encoder_stage1": torch.LongTensor(out[2]), - "input_stage2": torch.LongTensor(out[3]), - "target_stage2": torch.LongTensor(out[4]), - "encoder_stage2": torch.LongTensor(out[5]), - "input_stage3": torch.LongTensor(out[6]), - "target_stage3": torch.LongTensor(out[7]), - "encoder_stage3": torch.LongTensor(out[8]), - "struct_inject": struct_inject, - "target_density": target_density - } + return self.pack_sample(item, map_np, out) diff --git a/ginka/train_seperated.py b/ginka/train_seperated.py index c712938..8b0cdca 100644 --- a/ginka/train_seperated.py +++ b/ginka/train_seperated.py @@ -38,7 +38,7 @@ from shared.image import matrix_to_image_cv # 共用 VQ-VAE 超参 # 三组编码器(vq1/vq2/vq3)共享相同超参,分别对三阶段地图上下文独立编码 VQ_L = 16 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3) -VQ_K = 16 # codebook 大小(离散码本条目数) +VQ_K = 32 # codebook 大小(离散码本条目数) VQ_D_Z = 64 # 码字维度 VQ_BETA = 1.0 # commit loss 权重(防止编码器输出漂离 codebook) VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用) @@ -267,28 +267,6 @@ def sample_reference_inputs( return sampled_reference -def random_struct(device: torch.device) -> torch.Tensor: - # 随机采样一组结构参量,用于无条件自由生成 - # struct_inject 格式:[cond_sym(0-7), cond_outer(0-1)] - cond_sym = random.randint(0, 7) # 地图对称类型 - cond_outer = random.randint(0, 1) # 是否有外围走廈 - return torch.LongTensor([cond_sym, cond_outer]).unsqueeze(0).to(device) - -def random_target_density(density_stats: dict, device: torch.device) -> torch.Tensor: - # 从训练集真实密度范围中采样 wall / door / monster / entrance / resource 目标密度 - wall_density = random.uniform(density_stats["wall_min_density"], density_stats["wall_max_density"]) - door_density = random.uniform(density_stats["door_min_density"], density_stats["door_max_density"]) - monster_density = random.uniform(density_stats["monster_min_density"], density_stats["monster_max_density"]) - entrance_density = random.uniform(density_stats["entrance_min_density"], density_stats["entrance_max_density"]) - resource_density = random.uniform(density_stats["resource_min_density"], density_stats["resource_max_density"]) - return torch.FloatTensor([ - wall_density, - door_density, - monster_density, - entrance_density, - resource_density, - ]).unsqueeze(0).to(device) - def compute_remaining( current: torch.Tensor, target_density: torch.Tensor, @@ -416,50 +394,6 @@ def maskgit_sample( return current[0].cpu().numpy().reshape(MAP_H, MAP_W) -def full_generate_random_z( - input: torch.Tensor, - struct: torch.Tensor, - target_density: torch.Tensor, - models: list[torch.nn.Module], - device: torch.device, - keep_fixed: tuple[bool, bool, bool] = (True, True, True) -) -> tuple: - vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models - quantizer1, quantizer2, quantizer3 = quantizers - - with torch.no_grad(): - z1 = quantizer1.sample(1, VQ_L, device) - z2 = quantizer2.sample(1, VQ_L, device) - z3 = quantizer3.sample(1, VQ_L, device) - - # stage1:生成墙壁骨架 - pred1_np = maskgit_sample( - mg1, input.clone(), z1, struct, target_density, 1, - GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0] - ) - inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) - inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充 - - # stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并 - pred2_np = maskgit_sample( - mg2, inp2, z2, struct, target_density, 2, - GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1] - ) - merged12 = pred1_np.copy() - merged12[pred2_np != 0] = pred2_np[pred2_np != 0] - inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) - inp3[inp3 == 0] = MASK_TOKEN - - # stage3:填充 resource(3) - pred3_np = maskgit_sample( - mg3, inp3, z3, struct, target_density, 3, - GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2] - ) - merged123 = merged12.copy() - merged123[pred3_np != 0] = pred3_np[pred3_np != 0] - - return pred1_np, merged12, merged123 - def full_generate_specific_z( input: torch.Tensor, z_q: tuple[torch.Tensor, torch.Tensor, torch.Tensor], @@ -474,7 +408,7 @@ def full_generate_specific_z( with torch.no_grad(): z1, z2, z3 = z_q - # 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z + # 三阶段级联生成,但使用给定的 z pred1_np = maskgit_sample( mg1, input.clone(), z1, struct, target_density, 1, GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0] @@ -500,11 +434,11 @@ def full_generate_specific_z( return pred1_np, merged12, merged123 -def annotate(img: np.ndarray, text: str) -> np.ndarray: +def annotate(img: np.ndarray, text: str, y: int = 14) -> np.ndarray: # 在图片左上角叠加文字标注(黑色描边 + 白色填充,确保任意背景下可读) img = img.copy() - cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2) - cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) + cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2) + cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) return img def annotate_labels( @@ -531,6 +465,40 @@ def rand_keep() -> tuple[bool, bool, bool]: def keep_label(kf: tuple[bool, bool, bool]) -> str: return 'fix' if kf[0] else 'free' +def build_dataset_sample_case( + dataset: GinkaSeperatedDataset, + models: list[torch.nn.Module], + device: torch.device, + idx: int | None = None +) -> dict: + vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models + sample = dataset.random_sample_map(idx=idx) + + enc1_t = sample["encoder_stage1"].to(device).reshape(1, MAP_SIZE) + enc2_t = sample["encoder_stage2"].to(device).reshape(1, MAP_SIZE) + enc3_t = sample["encoder_stage3"].to(device).reshape(1, MAP_SIZE) + struct_t = sample["struct_inject"].to(device).reshape(1, -1) + target_density_t = sample["target_density"].to(device).reshape(1, -1) + + with torch.no_grad(): + z_e1 = vq1(enc1_t) + z_e2 = vq2(enc2_t) + z_e3 = vq3(enc3_t) + z_q, commit_loss, code_hits = quantize_stage_latents( + quantizers, z_e1, z_e2, z_e3 + ) + + return { + "sample": sample, + "struct": struct_t, + "target_density": target_density_t, + "z_q": z_q, + "sample_idx": sample["sample_idx"] + } + +def sample_case_label(case: dict) -> str: + return f"train#{case['sample_idx']}" + # 验证可视化 part1:3×3 网格;行1=编码器输入,行2=掩码输入,行3=三阶段预测(合并) def visualize_part1(batch, logits1, logits2, logits3, tile_dict): SEP = 3 @@ -618,52 +586,13 @@ def visualize_part2(batch, z_q, models, device, tile_dict): grid[y:y + img_h, x:x + img_w] = img return grid -# 验证可视化 part3:2×3 网格;行1=参考输入+相同 struct 随机 z 生成,行2=随机 struct 生成 -def visualize_part3(batch, models, device, tile_dict, density_stats: dict): - SEP = 3 - TILE_SIZE = 32 - img_h = MAP_H * TILE_SIZE - img_w = MAP_W * TILE_SIZE - - def to_img(mat): - return matrix_to_image_cv(mat, tile_dict, TILE_SIZE) - - inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE) - struct_ref = batch["struct_inject"][0:1].to(device) - target_density_ref = batch["target_density"][0:1].to(device) - inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W) - struct_cpu = batch["struct_inject"][0] - target_density_cpu = batch["target_density"][0] - - row1 = [to_img(inp1_np)] - for _ in range(2): - kf = rand_keep() - _, _, merged123 = full_generate_random_z( - inp1_t, struct_ref, target_density_ref, models, device, keep_fixed=kf - ) - row1.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu)) - - row2 = [] - for _ in range(3): - kf = rand_keep() - rnd_struct = random_struct(device) - rnd_target_density = random_target_density(density_stats, device) - _, _, merged123 = full_generate_random_z( - inp1_t, rnd_struct, rnd_target_density, models, device, keep_fixed=kf - ) - row2.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_target_density[0].cpu())) - - rows = [row1, row2] - grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 - for r, row in enumerate(rows): - for c, img in enumerate(row): - y = SEP + r * (img_h + SEP) - x = SEP + c * (img_w + SEP) - grid[y:y + img_h, x:x + img_w] = img - return grid - -# 验证可视化 part4:2×3 网格;以少量随机墙壁作为种子,纯随机 struct+z 自由生成 -def visualize_part4(models, device, tile_dict, density_stats: dict): +# 验证可视化 part4:2×3 网格;保留稀疏墙壁种子,但 z 与标签来自训练集样本 +def visualize_part4( + train_dataset: GinkaSeperatedDataset, + models: list[torch.nn.Module], + device: torch.device, + tile_dict +): SEP = 3 TILE_SIZE = 32 img_h = MAP_H * TILE_SIZE @@ -680,15 +609,21 @@ def visualize_part4(models, device, tile_dict, density_stats: dict): results = [] for _ in range(5): + case = build_dataset_sample_case(train_dataset, models, device) kf = rand_keep() - rnd_struct = random_struct(device) - rnd_target_density = random_target_density(density_stats, device) - _, _, merged123 = full_generate_random_z( - seed, rnd_struct, rnd_target_density, models, device, keep_fixed=kf + sample = case["sample"] + _, _, merged123 = full_generate_specific_z( + seed, case["z_q"], case["struct"], case["target_density"], + models, device, keep_fixed=kf + ) + result = annotate_labels( + to_img(merged123), sample["struct_inject"], sample["target_density"] + ) + results.append( + annotate(result, f"{sample_case_label(case)} {keep_label(kf)}", y=50) ) - results.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_target_density[0].cpu())) - row1 = [to_img(seed_np)] + results[:2] + row1 = [annotate(to_img(seed_np), 'seed')] + results[:2] row2 = results[2:] rows = [row1, row2] grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 @@ -702,92 +637,20 @@ def visualize_part4(models, device, tile_dict, density_stats: dict): def visualize_validate( batch, logits1, logits2, logits3, z_q, models: list[torch.nn.Module], device: torch.device, tile_dict, - density_stats: dict, epoch: int, batch_idx: int + train_dataset: GinkaSeperatedDataset, epoch: int, batch_idx: int ): save_dir = f"result/seperated/e{epoch}" os.makedirs(save_dir, exist_ok=True) cv2.imwrite(f"{save_dir}/val{batch_idx}.png", visualize_part1(batch, logits1, logits2, logits3, tile_dict)) cv2.imwrite(f"{save_dir}/full{batch_idx}.png", visualize_part2(batch, z_q, models, device, tile_dict)) - cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part3(batch, models, device, tile_dict, density_stats)) - cv2.imwrite(f"{save_dir}/dvar{batch_idx}.png", visualize_density_var(batch, z_q, models, device, tile_dict)) - -# 密度对照图:随机种子+随机结构,5 张随机密度生成,2×3 网格(左上角为种子图) -def visualize_density_cmp(models, device, tile_dict, density_stats: dict): - SEP = 3 - TILE_SIZE = 32 - img_h = MAP_H * TILE_SIZE - img_w = MAP_W * TILE_SIZE - - def to_img(mat): - return matrix_to_image_cv(mat, tile_dict, TILE_SIZE) - - n_walls = random.randint(math.floor(MAP_SIZE * 0.02), math.floor(MAP_SIZE * 0.06)) - seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device) - wall_pos = torch.randperm(MAP_SIZE, device=device)[:n_walls] - seed[0, wall_pos] = 1 - seed_np = seed[0].cpu().numpy().reshape(MAP_H, MAP_W) - rnd_struct = random_struct(device) - struct_cpu = rnd_struct[0].cpu() - gen_imgs = [] - for _ in range(5): - rnd_target_density = random_target_density(density_stats, device) - target_density_cpu = rnd_target_density[0].cpu() - _, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_target_density, models, device) - gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu)) - row1 = [to_img(seed_np)] + gen_imgs[:2] - row2 = gen_imgs[2:] - rows = [row1, row2] - grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 - for r, row in enumerate(rows): - for c, img in enumerate(row): - y = SEP + r * (img_h + SEP) - x = SEP + c * (img_w + SEP) - grid[y:y + img_h, x:x + img_w] = img - return grid - -# 固定 z 和结构条件,扫描 5 个不同墙壁目标密度,2×3 网格(左上角为参考地图) -def visualize_density_var(batch, z_q, models, device, tile_dict): - SEP = 3 - TILE_SIZE = 32 - img_h = MAP_H * TILE_SIZE - img_w = MAP_W * TILE_SIZE - - def to_img(mat): - return matrix_to_image_cv(mat, tile_dict, TILE_SIZE) - - inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE) - struct_t = batch["struct_inject"][0:1].to(device) - struct_cpu = batch["struct_inject"][0] - base_target_density = batch["target_density"][0:1].to(device) - z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1]) - ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W) - gen_imgs = [] - wall_count_values = [20, 35, 50, 65, 80] - for wall_count in wall_count_values: - fixed_target_density = base_target_density.clone() - fixed_target_density[0, WALL_DENSITY_IDX] = wall_count / MAP_SIZE - target_density_cpu = fixed_target_density[0].cpu() - _, _, merged123 = full_generate_specific_z( - inp1_t, z_q_single, struct_t, fixed_target_density, models, device - ) - gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu)) - row1 = [to_img(ref_np)] + gen_imgs[:2] - row2 = gen_imgs[2:] - rows = [row1, row2] - grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 - for r, row in enumerate(rows): - for c, img in enumerate(row): - y = SEP + r * (img_h + SEP) - x = SEP + c * (img_w + SEP) - grid[y:y + img_h, x:x + img_w] = img - return grid + cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part4(train_dataset, models, device, tile_dict)) def validate( dataloader: DataLoader, models: list[torch.nn.Module], device: torch.device, tile_dict, - density_stats: dict, + train_dataset: GinkaSeperatedDataset, epoch: int ): vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models @@ -885,7 +748,7 @@ def validate( # 每个 batch 生成三种可视化图(val/full/rand) visualize_validate( batch, logits1, logits2, logits3, z_q, - models, device, tile_dict, density_stats, epoch, idx + models, device, tile_dict, train_dataset, epoch, idx ) idx += 1 @@ -896,12 +759,6 @@ def validate( avg_over = density_metrics[tile_id]["over"] / count if count > 0 else 0.0 tqdm.write(f" density {tile_names[tile_id]}: mae={avg_mae:.4f} over={avg_over:.4f}") - save_dir = f"result/seperated/e{epoch}" - os.makedirs(save_dir, exist_ok=True) - # 每个 epoch 额外生成:无条件自由生成图 + 全局密度对照图 - cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict, density_stats)) - cv2.imwrite(f"{save_dir}/density_cmp.png", visualize_density_cmp(models, device, tile_dict, density_stats)) - # 恢复训练模式 for m in [vq1, vq2, vq3, mg1, mg2, mg3]: m.train() @@ -1080,7 +937,9 @@ def train(device: torch.device): # 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存 if (epoch + 1) % CHECKPOINT == 0: - losses = validate(dataloader_val, models, device, tile_dict, dataset.density_stats, epoch + 1) + losses = validate( + dataloader_val, models, device, tile_dict, dataset, epoch + 1 + ) loss1_total, loss2_total, loss3_total, commit_total, code_hits_total = losses loss1_weighted = STAGE1_CE_WEIGHT * loss1_total loss2_weighted = STAGE2_CE_WEIGHT * loss2_total