mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 15:01:10 +08:00
chore: 完善验证流程
This commit is contained in:
parent
5f542fb577
commit
7c85b2b8cb
@ -158,7 +158,7 @@ def build_model(device: torch.device):
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
||||
|
||||
# 共用 VectorQuantizer:不参与梯度更新,仅在前向时做码本查表
|
||||
quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z)
|
||||
quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
|
||||
return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler
|
||||
|
||||
@ -181,8 +181,10 @@ def random_struct(device: torch.device) -> torch.Tensor:
|
||||
|
||||
def maskgit_sample(
|
||||
model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
|
||||
struct: torch.Tensor, steps: int
|
||||
struct: torch.Tensor, steps: int, keep_fixed: bool = True
|
||||
) -> np.ndarray:
|
||||
# keep_fixed=True:锁定输入中已有的非掩码位,使上一阶段结构保持不变
|
||||
# keep_fixed=False:所有位置均可被模型自由重估,适合探索更多样的生成结果
|
||||
current = inp.clone()
|
||||
|
||||
# 迭代去掩码:每步根据置信度分数重新决定掩码位置
|
||||
@ -199,10 +201,11 @@ def maskgit_sample(
|
||||
ratio = math.cos(((step + 1) / steps) * math.pi / 2)
|
||||
num_to_mask = math.floor(ratio * MAP_SIZE)
|
||||
|
||||
# 输入中已有的非掩码位(来自上一阶段)保持不变
|
||||
fixed_mask = (current[0] != MASK_TOKEN)
|
||||
sampled[0, fixed_mask] = current[0, fixed_mask]
|
||||
confidences[0, fixed_mask] = 1.0
|
||||
if keep_fixed:
|
||||
# 输入中已有的非掩码位(来自上一阶段)保持不变
|
||||
fixed_mask = (current[0] != MASK_TOKEN)
|
||||
sampled[0, fixed_mask] = current[0, fixed_mask]
|
||||
confidences[0, fixed_mask] = 1.0
|
||||
|
||||
if num_to_mask > 0:
|
||||
# 将置信度最低的位重新掩码,留待下一步重新预测
|
||||
@ -226,7 +229,8 @@ def full_generate_random_z(
|
||||
input: torch.Tensor,
|
||||
struct: torch.Tensor,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device
|
||||
device: torch.device,
|
||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||
) -> tuple:
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||
|
||||
@ -234,19 +238,19 @@ def full_generate_random_z(
|
||||
z = quantizer.sample(1, VQ_L, device)
|
||||
|
||||
# stage1:生成 floor/wall 骨架
|
||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP)
|
||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0])
|
||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充
|
||||
|
||||
# stage2:在骨架上生成 door/monster/entrance,非零结果覆盖合并
|
||||
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP)
|
||||
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
|
||||
merged12 = pred1_np.copy()
|
||||
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
||||
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp3[inp3 == 0] = MASK_TOKEN
|
||||
|
||||
# stage3:填充 resource
|
||||
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP)
|
||||
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
|
||||
merged123 = merged12.copy()
|
||||
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
||||
|
||||
@ -257,28 +261,43 @@ def full_generate_specific_z(
|
||||
z: torch.Tensor,
|
||||
struct: torch.Tensor,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device
|
||||
device: torch.device,
|
||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||
) -> tuple:
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||
|
||||
with torch.no_grad():
|
||||
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
|
||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP)
|
||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0])
|
||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp2[inp2 == 0] = MASK_TOKEN
|
||||
|
||||
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP)
|
||||
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
|
||||
merged12 = pred1_np.copy()
|
||||
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
||||
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp3[inp3 == 0] = MASK_TOKEN
|
||||
|
||||
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP)
|
||||
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
|
||||
merged123 = merged12.copy()
|
||||
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
||||
|
||||
return pred1_np, merged12, merged123
|
||||
|
||||
def annotate(img: np.ndarray, text: str) -> np.ndarray:
|
||||
# 在图片左上角叠加文字标注(黑色描边 + 白色填充,确保任意背景下可读)
|
||||
img = img.copy()
|
||||
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2)
|
||||
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
|
||||
return img
|
||||
|
||||
def rand_keep() -> tuple[bool, bool, bool]:
|
||||
b = random.choice([True, False])
|
||||
return (b, b, b)
|
||||
|
||||
def keep_label(kf: tuple[bool, bool, bool]) -> str:
|
||||
return 'fix' if kf[0] else 'free'
|
||||
|
||||
# 验证可视化 part1:3×3 网格;行1=编码器输入,行2=掩码输入,行3=三阶段预测(合并)
|
||||
def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
|
||||
SEP = 3
|
||||
@ -293,10 +312,6 @@ def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
|
||||
pred2 = torch.argmax(logits2[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
|
||||
pred3 = torch.argmax(logits3[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
|
||||
|
||||
pred3_merged = pred1.copy()
|
||||
pred3_merged[pred2 != 0] = pred2[pred2 != 0]
|
||||
pred3_merged[pred3 != 0] = pred3[pred3 != 0]
|
||||
|
||||
enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
@ -304,10 +319,18 @@ def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
|
||||
inp2_np = batch["input_stage2"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
inp3_np = batch["input_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
|
||||
# 将各阶段掩码输入中的 MASK 位用模型预测值填充,保留非掩码位原值
|
||||
result1 = inp1_np.copy()
|
||||
result1[inp1_np == MASK_TOKEN] = pred1[inp1_np == MASK_TOKEN]
|
||||
result2 = inp2_np.copy()
|
||||
result2[inp2_np == MASK_TOKEN] = pred2[inp2_np == MASK_TOKEN]
|
||||
result3 = inp3_np.copy()
|
||||
result3[inp3_np == MASK_TOKEN] = pred3[inp3_np == MASK_TOKEN]
|
||||
|
||||
rows = [
|
||||
[to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
|
||||
[to_img(inp1_np), to_img(inp2_np), to_img(inp3_np)],
|
||||
[to_img(pred1), to_img(pred2), to_img(pred3_merged)],
|
||||
[to_img(result1), to_img(result2), to_img(result3)],
|
||||
]
|
||||
grid = np.ones((3 * img_h + 4 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||
for r, row in enumerate(rows):
|
||||
@ -329,9 +352,14 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
||||
|
||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||
struct_t = batch["struct_inject"][0:1].to(device)
|
||||
kf = rand_keep()
|
||||
auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
|
||||
inp1_t, z_q[0:1], struct_t, models, device
|
||||
inp1_t, z_q[0:1], struct_t, models, device, keep_fixed=kf
|
||||
)
|
||||
kf_label = 'fix' if kf[0] else 'free'
|
||||
label1 = f"s1:{kf_label}"
|
||||
label2 = f"s2:{kf_label}"
|
||||
label3 = f"s3:{kf_label}"
|
||||
|
||||
enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
@ -340,7 +368,7 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
||||
|
||||
rows = [
|
||||
[to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
|
||||
[to_img(inp1_np), to_img(auto_pred1_np), to_img(auto_merged12), to_img(auto_merged123)],
|
||||
[to_img(inp1_np), annotate(to_img(auto_pred1_np), label1), annotate(to_img(auto_merged12), label2), annotate(to_img(auto_merged123), label3)],
|
||||
]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255
|
||||
for r, row in enumerate(rows):
|
||||
@ -366,13 +394,15 @@ def visualize_part3(batch, models, device, tile_dict):
|
||||
|
||||
row1 = [to_img(inp1_np)]
|
||||
for _ in range(2):
|
||||
_, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device)
|
||||
row1.append(to_img(merged123))
|
||||
kf = rand_keep()
|
||||
_, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device, keep_fixed=kf)
|
||||
row1.append(annotate(to_img(merged123), keep_label(kf)))
|
||||
|
||||
row2 = []
|
||||
for _ in range(3):
|
||||
_, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device)
|
||||
row2.append(to_img(merged123))
|
||||
kf = rand_keep()
|
||||
_, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device, keep_fixed=kf)
|
||||
row2.append(annotate(to_img(merged123), keep_label(kf)))
|
||||
|
||||
rows = [row1, row2]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||
@ -401,8 +431,9 @@ def visualize_part4(models, device, tile_dict):
|
||||
|
||||
results = []
|
||||
for _ in range(5):
|
||||
_, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device)
|
||||
results.append(to_img(merged123))
|
||||
kf = rand_keep()
|
||||
_, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device, keep_fixed=kf)
|
||||
results.append(annotate(to_img(merged123), keep_label(kf)))
|
||||
|
||||
row1 = [to_img(seed_np)] + results[:2]
|
||||
row2 = results[2:]
|
||||
@ -664,3 +695,6 @@ def train(device: torch.device):
|
||||
"scheduler": scheduler.state_dict(),
|
||||
}, final_path)
|
||||
tqdm.write(f"Training complete. Final model saved: {final_path}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
train(device)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user