mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 02:44:51 +08:00
chore: 完善验证流程
This commit is contained in:
parent
5f542fb577
commit
7c85b2b8cb
@ -158,7 +158,7 @@ def build_model(device: torch.device):
|
|||||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
||||||
|
|
||||||
# 共用 VectorQuantizer:不参与梯度更新,仅在前向时做码本查表
|
# 共用 VectorQuantizer:不参与梯度更新,仅在前向时做码本查表
|
||||||
quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z)
|
quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||||
|
|
||||||
return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler
|
return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler
|
||||||
|
|
||||||
@ -181,8 +181,10 @@ def random_struct(device: torch.device) -> torch.Tensor:
|
|||||||
|
|
||||||
def maskgit_sample(
|
def maskgit_sample(
|
||||||
model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
|
model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
|
||||||
struct: torch.Tensor, steps: int
|
struct: torch.Tensor, steps: int, keep_fixed: bool = True
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
# keep_fixed=True:锁定输入中已有的非掩码位,使上一阶段结构保持不变
|
||||||
|
# keep_fixed=False:所有位置均可被模型自由重估,适合探索更多样的生成结果
|
||||||
current = inp.clone()
|
current = inp.clone()
|
||||||
|
|
||||||
# 迭代去掩码:每步根据置信度分数重新决定掩码位置
|
# 迭代去掩码:每步根据置信度分数重新决定掩码位置
|
||||||
@ -199,10 +201,11 @@ def maskgit_sample(
|
|||||||
ratio = math.cos(((step + 1) / steps) * math.pi / 2)
|
ratio = math.cos(((step + 1) / steps) * math.pi / 2)
|
||||||
num_to_mask = math.floor(ratio * MAP_SIZE)
|
num_to_mask = math.floor(ratio * MAP_SIZE)
|
||||||
|
|
||||||
# 输入中已有的非掩码位(来自上一阶段)保持不变
|
if keep_fixed:
|
||||||
fixed_mask = (current[0] != MASK_TOKEN)
|
# 输入中已有的非掩码位(来自上一阶段)保持不变
|
||||||
sampled[0, fixed_mask] = current[0, fixed_mask]
|
fixed_mask = (current[0] != MASK_TOKEN)
|
||||||
confidences[0, fixed_mask] = 1.0
|
sampled[0, fixed_mask] = current[0, fixed_mask]
|
||||||
|
confidences[0, fixed_mask] = 1.0
|
||||||
|
|
||||||
if num_to_mask > 0:
|
if num_to_mask > 0:
|
||||||
# 将置信度最低的位重新掩码,留待下一步重新预测
|
# 将置信度最低的位重新掩码,留待下一步重新预测
|
||||||
@ -226,7 +229,8 @@ def full_generate_random_z(
|
|||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
struct: torch.Tensor,
|
struct: torch.Tensor,
|
||||||
models: list[torch.nn.Module],
|
models: list[torch.nn.Module],
|
||||||
device: torch.device
|
device: torch.device,
|
||||||
|
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||||
|
|
||||||
@ -234,19 +238,19 @@ def full_generate_random_z(
|
|||||||
z = quantizer.sample(1, VQ_L, device)
|
z = quantizer.sample(1, VQ_L, device)
|
||||||
|
|
||||||
# stage1:生成 floor/wall 骨架
|
# stage1:生成 floor/wall 骨架
|
||||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP)
|
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0])
|
||||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||||
inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充
|
inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充
|
||||||
|
|
||||||
# stage2:在骨架上生成 door/monster/entrance,非零结果覆盖合并
|
# stage2:在骨架上生成 door/monster/entrance,非零结果覆盖合并
|
||||||
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP)
|
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
|
||||||
merged12 = pred1_np.copy()
|
merged12 = pred1_np.copy()
|
||||||
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
||||||
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||||
inp3[inp3 == 0] = MASK_TOKEN
|
inp3[inp3 == 0] = MASK_TOKEN
|
||||||
|
|
||||||
# stage3:填充 resource
|
# stage3:填充 resource
|
||||||
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP)
|
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
|
||||||
merged123 = merged12.copy()
|
merged123 = merged12.copy()
|
||||||
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
||||||
|
|
||||||
@ -257,28 +261,43 @@ def full_generate_specific_z(
|
|||||||
z: torch.Tensor,
|
z: torch.Tensor,
|
||||||
struct: torch.Tensor,
|
struct: torch.Tensor,
|
||||||
models: list[torch.nn.Module],
|
models: list[torch.nn.Module],
|
||||||
device: torch.device
|
device: torch.device,
|
||||||
|
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
|
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
|
||||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP)
|
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0])
|
||||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||||
inp2[inp2 == 0] = MASK_TOKEN
|
inp2[inp2 == 0] = MASK_TOKEN
|
||||||
|
|
||||||
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP)
|
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
|
||||||
merged12 = pred1_np.copy()
|
merged12 = pred1_np.copy()
|
||||||
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
||||||
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||||
inp3[inp3 == 0] = MASK_TOKEN
|
inp3[inp3 == 0] = MASK_TOKEN
|
||||||
|
|
||||||
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP)
|
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
|
||||||
merged123 = merged12.copy()
|
merged123 = merged12.copy()
|
||||||
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
||||||
|
|
||||||
return pred1_np, merged12, merged123
|
return pred1_np, merged12, merged123
|
||||||
|
|
||||||
|
def annotate(img: np.ndarray, text: str) -> np.ndarray:
|
||||||
|
# 在图片左上角叠加文字标注(黑色描边 + 白色填充,确保任意背景下可读)
|
||||||
|
img = img.copy()
|
||||||
|
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2)
|
||||||
|
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def rand_keep() -> tuple[bool, bool, bool]:
|
||||||
|
b = random.choice([True, False])
|
||||||
|
return (b, b, b)
|
||||||
|
|
||||||
|
def keep_label(kf: tuple[bool, bool, bool]) -> str:
|
||||||
|
return 'fix' if kf[0] else 'free'
|
||||||
|
|
||||||
# 验证可视化 part1:3×3 网格;行1=编码器输入,行2=掩码输入,行3=三阶段预测(合并)
|
# 验证可视化 part1:3×3 网格;行1=编码器输入,行2=掩码输入,行3=三阶段预测(合并)
|
||||||
def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
|
def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
|
||||||
SEP = 3
|
SEP = 3
|
||||||
@ -293,10 +312,6 @@ def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
|
|||||||
pred2 = torch.argmax(logits2[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
|
pred2 = torch.argmax(logits2[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
|
||||||
pred3 = torch.argmax(logits3[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
|
pred3 = torch.argmax(logits3[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
|
||||||
|
|
||||||
pred3_merged = pred1.copy()
|
|
||||||
pred3_merged[pred2 != 0] = pred2[pred2 != 0]
|
|
||||||
pred3_merged[pred3 != 0] = pred3[pred3 != 0]
|
|
||||||
|
|
||||||
enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
||||||
enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
|
enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
|
||||||
enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
||||||
@ -304,10 +319,18 @@ def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
|
|||||||
inp2_np = batch["input_stage2"][0].numpy().reshape(MAP_H, MAP_W)
|
inp2_np = batch["input_stage2"][0].numpy().reshape(MAP_H, MAP_W)
|
||||||
inp3_np = batch["input_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
inp3_np = batch["input_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
||||||
|
|
||||||
|
# 将各阶段掩码输入中的 MASK 位用模型预测值填充,保留非掩码位原值
|
||||||
|
result1 = inp1_np.copy()
|
||||||
|
result1[inp1_np == MASK_TOKEN] = pred1[inp1_np == MASK_TOKEN]
|
||||||
|
result2 = inp2_np.copy()
|
||||||
|
result2[inp2_np == MASK_TOKEN] = pred2[inp2_np == MASK_TOKEN]
|
||||||
|
result3 = inp3_np.copy()
|
||||||
|
result3[inp3_np == MASK_TOKEN] = pred3[inp3_np == MASK_TOKEN]
|
||||||
|
|
||||||
rows = [
|
rows = [
|
||||||
[to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
|
[to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
|
||||||
[to_img(inp1_np), to_img(inp2_np), to_img(inp3_np)],
|
[to_img(inp1_np), to_img(inp2_np), to_img(inp3_np)],
|
||||||
[to_img(pred1), to_img(pred2), to_img(pred3_merged)],
|
[to_img(result1), to_img(result2), to_img(result3)],
|
||||||
]
|
]
|
||||||
grid = np.ones((3 * img_h + 4 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
grid = np.ones((3 * img_h + 4 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||||
for r, row in enumerate(rows):
|
for r, row in enumerate(rows):
|
||||||
@ -329,9 +352,14 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
|||||||
|
|
||||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||||
struct_t = batch["struct_inject"][0:1].to(device)
|
struct_t = batch["struct_inject"][0:1].to(device)
|
||||||
|
kf = rand_keep()
|
||||||
auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
|
auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
|
||||||
inp1_t, z_q[0:1], struct_t, models, device
|
inp1_t, z_q[0:1], struct_t, models, device, keep_fixed=kf
|
||||||
)
|
)
|
||||||
|
kf_label = 'fix' if kf[0] else 'free'
|
||||||
|
label1 = f"s1:{kf_label}"
|
||||||
|
label2 = f"s2:{kf_label}"
|
||||||
|
label3 = f"s3:{kf_label}"
|
||||||
|
|
||||||
enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
||||||
enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
|
enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
|
||||||
@ -340,7 +368,7 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
|||||||
|
|
||||||
rows = [
|
rows = [
|
||||||
[to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
|
[to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
|
||||||
[to_img(inp1_np), to_img(auto_pred1_np), to_img(auto_merged12), to_img(auto_merged123)],
|
[to_img(inp1_np), annotate(to_img(auto_pred1_np), label1), annotate(to_img(auto_merged12), label2), annotate(to_img(auto_merged123), label3)],
|
||||||
]
|
]
|
||||||
grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255
|
grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255
|
||||||
for r, row in enumerate(rows):
|
for r, row in enumerate(rows):
|
||||||
@ -366,13 +394,15 @@ def visualize_part3(batch, models, device, tile_dict):
|
|||||||
|
|
||||||
row1 = [to_img(inp1_np)]
|
row1 = [to_img(inp1_np)]
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
_, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device)
|
kf = rand_keep()
|
||||||
row1.append(to_img(merged123))
|
_, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device, keep_fixed=kf)
|
||||||
|
row1.append(annotate(to_img(merged123), keep_label(kf)))
|
||||||
|
|
||||||
row2 = []
|
row2 = []
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
_, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device)
|
kf = rand_keep()
|
||||||
row2.append(to_img(merged123))
|
_, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device, keep_fixed=kf)
|
||||||
|
row2.append(annotate(to_img(merged123), keep_label(kf)))
|
||||||
|
|
||||||
rows = [row1, row2]
|
rows = [row1, row2]
|
||||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||||
@ -401,8 +431,9 @@ def visualize_part4(models, device, tile_dict):
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
_, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device)
|
kf = rand_keep()
|
||||||
results.append(to_img(merged123))
|
_, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device, keep_fixed=kf)
|
||||||
|
results.append(annotate(to_img(merged123), keep_label(kf)))
|
||||||
|
|
||||||
row1 = [to_img(seed_np)] + results[:2]
|
row1 = [to_img(seed_np)] + results[:2]
|
||||||
row2 = results[2:]
|
row2 = results[2:]
|
||||||
@ -664,3 +695,6 @@ def train(device: torch.device):
|
|||||||
"scheduler": scheduler.state_dict(),
|
"scheduler": scheduler.state_dict(),
|
||||||
}, final_path)
|
}, final_path)
|
||||||
tqdm.write(f"Training complete. Final model saved: {final_path}")
|
tqdm.write(f"Training complete. Final model saved: {final_path}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train(device)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user