feat: 随机验证分为有条件和无条件

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-27 15:36:11 +08:00
parent 9bddb05625
commit 81d49944cc

View File

@ -42,6 +42,7 @@ GENERATE_STEP = 18 # 推理时 MaskGIT 迭代步数
MAP_SIZE = 13 * 13
MAP_H = MAP_W = 13
LABEL_SMOOTHING = 0.0
WALL_MASK_RATIO = 0.8
# VQ-VAE 超参
VQ_L = 16 # summary token 数量(即 z 的序列长度)
@ -288,6 +289,21 @@ def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> t
seed[0, idx] = 1 # wall
return seed
def make_random_struct_cond() -> torch.Tensor:
    """
    Draw one random structural condition in which every label is a
    legal, non-null value.

    Each label is sampled uniformly from ``0 .. VOCAB - 2`` inclusive
    (``random.randint`` includes both endpoints), so the highest index of
    every vocabulary is never produced.
    NOTE(review): presumably that highest index is the "null" label -
    confirm against ``maskGIT.model``.

    Returns:
        LongTensor of shape [1, 4], ordered as
        [cond_sym, cond_room, cond_branch, cond_outer], created on the
        module-level ``device``.
    """
    from .maskGIT.model import SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB

    # Order matters twice over: it fixes the column layout of the result
    # and the sequence in which the RNG is consumed.
    vocab_sizes = (SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB)
    labels = [random.randint(0, size - 2) for size in vocab_sizes]
    return torch.tensor([labels], dtype=torch.long, device=device)
@torch.no_grad()
def validate(
model_vq: GinkaVQVAE,
@ -356,7 +372,7 @@ def validate(
if all(v is not None for v in captured.values()):
break
# ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成 ──────────────────
# ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成(无条件)──────────────
def _rand_gens(cond_map, n):
imgs = []
for i in range(n):
@ -365,6 +381,18 @@ def validate(
imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
return imgs
# ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成(随机结构标签)────────
def _rand_gens_with_struct(cond_map, n):
    """
    Generate ``n`` maps from ``cond_map``, each with a freshly sampled
    latent ``z`` and a freshly drawn random structural condition.

    Every result is captioned ``z_rand_<k>`` (k starts at 1) and then
    annotated with the structural labels that were used for it.

    Args:
        cond_map: initial / conditioning map passed to
            ``maskgit_generate`` as ``init_map``.
        n: number of samples to produce.

    Returns:
        List of ``n`` annotated images, in generation order.
    """
    annotated = []
    for idx in range(n):
        latent = model_vq.sample(1, device)
        struct_labels = make_random_struct_cond()  # [1, 4] random legal labels
        generated = maskgit_generate(
            model_mg, latent, init_map=cond_map, struct_cond=struct_labels
        )
        tile_img = make_map_image(generated[0], tile_dict)
        captioned = label_image(tile_img, f"z_rand_{idx + 1}")
        annotated.append(annotate_struct(captioned, struct_labels[0]))
    return annotated
# ── 场景1: 标准掩码补全(子集 A)─────────────────────────────────────────
if captured['A'] is not None:
cap = captured['A']
@ -446,11 +474,17 @@ def validate(
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row))
# ── 场景5: 完全随机生成(无数据集参照)──────────────────────────────────
# 随机稀疏墙壁种子,对应推理时"不提供任何条件,直接生成"的场景
rand_seed = make_random_wall_seed() # [1, MAP_SIZE]
seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed")
row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1) # 多采一个 z 展示多样性
cv2.imwrite(f"{epoch_dir}/scene5_random.png", grid_images(row))
# 5a: 随机结构标签 — 验证结构导向能力
rand_seed_a = make_random_wall_seed()
seed_img_a = label_image(make_map_image(rand_seed_a[0], tile_dict), "random seed")
row_a = [seed_img_a] + _rand_gens_with_struct(rand_seed_a, N_Z_SAMPLES + 1)
cv2.imwrite(f"{epoch_dir}/scene5a_random_cond.png", grid_images(row_a))
# 5b: 无条件(struct_cond=None) — 验证基线生成质量
rand_seed_b = make_random_wall_seed()
seed_img_b = label_image(make_map_image(rand_seed_b[0], tile_dict), "random seed")
row_b = [seed_img_b] + _rand_gens(rand_seed_b, N_Z_SAMPLES + 1)
cv2.imwrite(f"{epoch_dir}/scene5b_random_uncond.png", grid_images(row_b))
avg_val_loss = val_loss_total / max(val_steps, 1)
return avg_val_loss
@ -492,12 +526,14 @@ def train():
dataset_train = GinkaVQDataset(
args.train,
subset_weights=SUBSET_WEIGHTS,
wall_mask_ratio=WALL_MASK_RATIO,
)
dataset_val = GinkaVQDataset(
args.validate,
subset_weights=SUBSET_WEIGHTS,
room_thresholds=dataset_train.room_th,
branch_thresholds=dataset_train.branch_th,
wall_mask_ratio=WALL_MASK_RATIO,
)
dataloader_train = DataLoader(
dataset_train, batch_size=BATCH_SIZE, shuffle=True,