chore: 完善验证流程

This commit is contained in:
unanmed 2026-05-13 21:05:44 +08:00
parent 5f542fb577
commit 7c85b2b8cb

View File

@ -158,7 +158,7 @@ def build_model(device: torch.device):
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
# 共用 VectorQuantizer:不参与梯度更新,仅在前向时做码本查表
quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z)
quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler
@ -181,8 +181,10 @@ def random_struct(device: torch.device) -> torch.Tensor:
def maskgit_sample(
model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
struct: torch.Tensor, steps: int
struct: torch.Tensor, steps: int, keep_fixed: bool = True
) -> np.ndarray:
# keep_fixed=True:锁定输入中已有的非掩码位,使上一阶段结构保持不变
# keep_fixed=False:所有位置均可被模型自由重估,适合探索更多样的生成结果
current = inp.clone()
# 迭代去掩码:每步根据置信度分数重新决定掩码位置
@ -199,10 +201,11 @@ def maskgit_sample(
ratio = math.cos(((step + 1) / steps) * math.pi / 2)
num_to_mask = math.floor(ratio * MAP_SIZE)
# 输入中已有的非掩码位(来自上一阶段)保持不变
fixed_mask = (current[0] != MASK_TOKEN)
sampled[0, fixed_mask] = current[0, fixed_mask]
confidences[0, fixed_mask] = 1.0
if keep_fixed:
# 输入中已有的非掩码位(来自上一阶段)保持不变
fixed_mask = (current[0] != MASK_TOKEN)
sampled[0, fixed_mask] = current[0, fixed_mask]
confidences[0, fixed_mask] = 1.0
if num_to_mask > 0:
# 将置信度最低的位重新掩码,留待下一步重新预测
@ -226,7 +229,8 @@ def full_generate_random_z(
input: torch.Tensor,
struct: torch.Tensor,
models: list[torch.nn.Module],
device: torch.device
device: torch.device,
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
) -> tuple:
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
@ -234,19 +238,19 @@ def full_generate_random_z(
z = quantizer.sample(1, VQ_L, device)
# stage1生成 floor/wall 骨架
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP)
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0])
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充
# stage2在骨架上生成 door/monster/entrance非零结果覆盖合并
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP)
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
merged12 = pred1_np.copy()
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp3[inp3 == 0] = MASK_TOKEN
# stage3填充 resource
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP)
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
merged123 = merged12.copy()
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
@ -257,28 +261,43 @@ def full_generate_specific_z(
z: torch.Tensor,
struct: torch.Tensor,
models: list[torch.nn.Module],
device: torch.device
device: torch.device,
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
) -> tuple:
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
with torch.no_grad():
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP)
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0])
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp2[inp2 == 0] = MASK_TOKEN
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP)
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
merged12 = pred1_np.copy()
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp3[inp3 == 0] = MASK_TOKEN
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP)
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
merged123 = merged12.copy()
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
return pred1_np, merged12, merged123
def annotate(img: np.ndarray, text: str) -> np.ndarray:
    """Overlay *text* at the top-left corner of a copy of *img*.

    A thick black pass is drawn first and a thin white pass on top of it,
    producing an outlined label that stays readable on any background.
    The input image is not modified; an annotated copy is returned.
    """
    canvas = img.copy()
    # (color, thickness) pairs: outline first, then fill.
    for color, thickness in (((0, 0, 0), 2), ((255, 255, 255), 1)):
        cv2.putText(canvas, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX,
                    0.4, color, thickness)
    return canvas
def rand_keep() -> tuple[bool, bool, bool]:
    """Draw a single random flag and apply it uniformly to all three stages.

    Returns either (True, True, True) or (False, False, False), so the
    keep_fixed behaviour is consistent across the cascade.
    """
    flag = random.choice([True, False])
    return (flag, flag, flag)
def keep_label(kf: tuple[bool, bool, bool]) -> str:
    """Map a keep_fixed triple to its display label.

    Only the first element is consulted (the three flags are set
    uniformly by rand_keep): True -> 'fix', False -> 'free'.
    """
    if kf[0]:
        return 'fix'
    return 'free'
# 验证可视化 part13×3 网格行1=编码器输入行2=掩码输入行3=三阶段预测(合并)
def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
SEP = 3
@ -293,10 +312,6 @@ def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
pred2 = torch.argmax(logits2[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
pred3 = torch.argmax(logits3[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
pred3_merged = pred1.copy()
pred3_merged[pred2 != 0] = pred2[pred2 != 0]
pred3_merged[pred3 != 0] = pred3[pred3 != 0]
enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
@ -304,10 +319,18 @@ def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
inp2_np = batch["input_stage2"][0].numpy().reshape(MAP_H, MAP_W)
inp3_np = batch["input_stage3"][0].numpy().reshape(MAP_H, MAP_W)
# 将各阶段掩码输入中的 MASK 位用模型预测值填充,保留非掩码位原值
result1 = inp1_np.copy()
result1[inp1_np == MASK_TOKEN] = pred1[inp1_np == MASK_TOKEN]
result2 = inp2_np.copy()
result2[inp2_np == MASK_TOKEN] = pred2[inp2_np == MASK_TOKEN]
result3 = inp3_np.copy()
result3[inp3_np == MASK_TOKEN] = pred3[inp3_np == MASK_TOKEN]
rows = [
[to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
[to_img(inp1_np), to_img(inp2_np), to_img(inp3_np)],
[to_img(pred1), to_img(pred2), to_img(pred3_merged)],
[to_img(result1), to_img(result2), to_img(result3)],
]
grid = np.ones((3 * img_h + 4 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
for r, row in enumerate(rows):
@ -329,9 +352,14 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
struct_t = batch["struct_inject"][0:1].to(device)
kf = rand_keep()
auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
inp1_t, z_q[0:1], struct_t, models, device
inp1_t, z_q[0:1], struct_t, models, device, keep_fixed=kf
)
kf_label = 'fix' if kf[0] else 'free'
label1 = f"s1:{kf_label}"
label2 = f"s2:{kf_label}"
label3 = f"s3:{kf_label}"
enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
@ -340,7 +368,7 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
rows = [
[to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
[to_img(inp1_np), to_img(auto_pred1_np), to_img(auto_merged12), to_img(auto_merged123)],
[to_img(inp1_np), annotate(to_img(auto_pred1_np), label1), annotate(to_img(auto_merged12), label2), annotate(to_img(auto_merged123), label3)],
]
grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255
for r, row in enumerate(rows):
@ -366,13 +394,15 @@ def visualize_part3(batch, models, device, tile_dict):
row1 = [to_img(inp1_np)]
for _ in range(2):
_, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device)
row1.append(to_img(merged123))
kf = rand_keep()
_, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device, keep_fixed=kf)
row1.append(annotate(to_img(merged123), keep_label(kf)))
row2 = []
for _ in range(3):
_, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device)
row2.append(to_img(merged123))
kf = rand_keep()
_, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device, keep_fixed=kf)
row2.append(annotate(to_img(merged123), keep_label(kf)))
rows = [row1, row2]
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
@ -401,8 +431,9 @@ def visualize_part4(models, device, tile_dict):
results = []
for _ in range(5):
_, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device)
results.append(to_img(merged123))
kf = rand_keep()
_, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device, keep_fixed=kf)
results.append(annotate(to_img(merged123), keep_label(kf)))
row1 = [to_img(seed_np)] + results[:2]
row2 = results[2:]
@ -664,3 +695,6 @@ def train(device: torch.device):
"scheduler": scheduler.state_dict(),
}, final_path)
tqdm.write(f"Training complete. Final model saved: {final_path}")
# Script entry point: run the full training loop when executed directly.
# NOTE(review): `device` is not defined in this view — presumably a
# module-level torch.device created earlier in the file; confirm.
if __name__ == "__main__":
    train(device)