mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 随机验证分为有条件和无条件
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
9bddb05625
commit
81d49944cc
@ -42,6 +42,7 @@ GENERATE_STEP = 18 # 推理时 MaskGIT 迭代步数
|
|||||||
MAP_SIZE = 13 * 13
|
MAP_SIZE = 13 * 13
|
||||||
MAP_H = MAP_W = 13
|
MAP_H = MAP_W = 13
|
||||||
LABEL_SMOOTHING = 0.0
|
LABEL_SMOOTHING = 0.0
|
||||||
|
WALL_MASK_RATIO = 0.8
|
||||||
|
|
||||||
# VQ-VAE 超参
|
# VQ-VAE 超参
|
||||||
VQ_L = 16 # summary token 数量(即 z 的序列长度)
|
VQ_L = 16 # summary token 数量(即 z 的序列长度)
|
||||||
@ -288,6 +289,21 @@ def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> t
|
|||||||
seed[0, idx] = 1 # wall
|
seed[0, idx] = 1 # wall
|
||||||
return seed
|
return seed
|
||||||
|
|
||||||
|
|
||||||
|
def make_random_struct_cond() -> torch.Tensor:
    """
    Build one random structure condition with every label set to a non-null value.

    Each label is drawn uniformly from ``[0, VOCAB - 2]`` (``random.randint``
    is inclusive on both ends), so the highest index of each vocabulary is
    never produced -- presumably that index is the null label; confirm
    against ``maskGIT.model``.

    Returns:
        [1, 4] LongTensor ordered as [cond_sym, cond_room, cond_branch, cond_outer],
        created on the ``device`` taken from the surrounding scope.
    """
    # Imported at call time, exactly as the original did.
    from .maskGIT.model import SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB

    # Order matters twice: it fixes the output column order and the order in
    # which the Python RNG is consumed (sym, room, branch, outer).
    vocab_sizes = (SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB)
    labels = [random.randint(0, size - 2) for size in vocab_sizes]
    return torch.tensor([labels], dtype=torch.long, device=device)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def validate(
|
def validate(
|
||||||
model_vq: GinkaVQVAE,
|
model_vq: GinkaVQVAE,
|
||||||
@ -356,7 +372,7 @@ def validate(
|
|||||||
if all(v is not None for v in captured.values()):
|
if all(v is not None for v in captured.values()):
|
||||||
break
|
break
|
||||||
|
|
||||||
# ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成 ──────────────────
|
# ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成(无条件)──────────────
|
||||||
def _rand_gens(cond_map, n):
|
def _rand_gens(cond_map, n):
|
||||||
imgs = []
|
imgs = []
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
@ -365,6 +381,18 @@ def validate(
|
|||||||
imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
|
imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
|
||||||
return imgs
|
return imgs
|
||||||
|
|
||||||
|
# ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成(随机结构标签)────────
|
||||||
|
def _rand_gens_with_struct(cond_map, n):
    """
    Generate ``n`` labelled preview images from ``cond_map``, each with a
    freshly sampled latent and a freshly drawn random structure condition.

    Args:
        cond_map: passed to ``maskgit_generate`` as ``init_map``.
        n: number of images to produce.

    Returns:
        List of ``n`` images, each titled ``z_rand_<k>`` (k starting at 1) and
        then passed through ``annotate_struct`` with the structure labels used.
    """
    def _render_one(sample_no):
        # Draw the latent first, then the structure labels, keeping the
        # original order of RNG consumption.
        latent = model_vq.sample(1, device)
        struct_cond = make_random_struct_cond()  # [1, 4] random non-null labels
        generated = maskgit_generate(
            model_mg, latent, init_map=cond_map, struct_cond=struct_cond
        )
        tile_img = make_map_image(generated[0], tile_dict)
        titled = label_image(tile_img, f"z_rand_{sample_no}")
        return annotate_struct(titled, struct_cond[0])

    return [_render_one(k) for k in range(1, n + 1)]
|
||||||
|
|
||||||
# ── 场景1:标准掩码补全(子集 A)─────────────────────────────────────────
|
# ── 场景1:标准掩码补全(子集 A)─────────────────────────────────────────
|
||||||
if captured['A'] is not None:
|
if captured['A'] is not None:
|
||||||
cap = captured['A']
|
cap = captured['A']
|
||||||
@ -446,11 +474,17 @@ def validate(
|
|||||||
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row))
|
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row))
|
||||||
|
|
||||||
# ── 场景5:完全随机生成(无数据集参照)──────────────────────────────────
|
# ── 场景5:完全随机生成(无数据集参照)──────────────────────────────────
|
||||||
# 随机稀疏墙壁种子,对应推理时"不提供任何条件,直接生成"的场景
|
# 5a:随机结构标签 — 验证结构导向能力
|
||||||
rand_seed = make_random_wall_seed() # [1, MAP_SIZE]
|
rand_seed_a = make_random_wall_seed()
|
||||||
seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed")
|
seed_img_a = label_image(make_map_image(rand_seed_a[0], tile_dict), "random seed")
|
||||||
row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1) # 多采一个 z 展示多样性
|
row_a = [seed_img_a] + _rand_gens_with_struct(rand_seed_a, N_Z_SAMPLES + 1)
|
||||||
cv2.imwrite(f"{epoch_dir}/scene5_random.png", grid_images(row))
|
cv2.imwrite(f"{epoch_dir}/scene5a_random_cond.png", grid_images(row_a))
|
||||||
|
|
||||||
|
# 5b:无条件(struct_cond=None)— 验证基线生成质量
|
||||||
|
rand_seed_b = make_random_wall_seed()
|
||||||
|
seed_img_b = label_image(make_map_image(rand_seed_b[0], tile_dict), "random seed")
|
||||||
|
row_b = [seed_img_b] + _rand_gens(rand_seed_b, N_Z_SAMPLES + 1)
|
||||||
|
cv2.imwrite(f"{epoch_dir}/scene5b_random_uncond.png", grid_images(row_b))
|
||||||
|
|
||||||
avg_val_loss = val_loss_total / max(val_steps, 1)
|
avg_val_loss = val_loss_total / max(val_steps, 1)
|
||||||
return avg_val_loss
|
return avg_val_loss
|
||||||
@ -492,12 +526,14 @@ def train():
|
|||||||
dataset_train = GinkaVQDataset(
|
dataset_train = GinkaVQDataset(
|
||||||
args.train,
|
args.train,
|
||||||
subset_weights=SUBSET_WEIGHTS,
|
subset_weights=SUBSET_WEIGHTS,
|
||||||
|
wall_mask_ratio=WALL_MASK_RATIO,
|
||||||
)
|
)
|
||||||
dataset_val = GinkaVQDataset(
|
dataset_val = GinkaVQDataset(
|
||||||
args.validate,
|
args.validate,
|
||||||
subset_weights=SUBSET_WEIGHTS,
|
subset_weights=SUBSET_WEIGHTS,
|
||||||
room_thresholds=dataset_train.room_th,
|
room_thresholds=dataset_train.room_th,
|
||||||
branch_thresholds=dataset_train.branch_th,
|
branch_thresholds=dataset_train.branch_th,
|
||||||
|
wall_mask_ratio=WALL_MASK_RATIO,
|
||||||
)
|
)
|
||||||
dataloader_train = DataLoader(
|
dataloader_train = DataLoader(
|
||||||
dataset_train, batch_size=BATCH_SIZE, shuffle=True,
|
dataset_train, batch_size=BATCH_SIZE, shuffle=True,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user