mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 10:21:15 +08:00
Compare commits
2 Commits
416aa4dd72
...
f006522cf9
| Author | SHA1 | Date | |
|---|---|---|---|
| f006522cf9 | |||
| 14ee52fb2f |
@ -71,6 +71,61 @@ class GinkaSeperatedDataset(Dataset):
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def build_struct_inject(self, map_np: np.ndarray, outer_wall: int) -> torch.Tensor:
|
||||
sym_h, sym_v, sym_c = compute_symmetry(map_np)
|
||||
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
|
||||
return torch.LongTensor([cond_sym, outer_wall])
|
||||
|
||||
def build_target_density(self, map_data: list) -> torch.Tensor:
|
||||
return torch.FloatTensor([
|
||||
self.count_tile(map_data, self.WALL) / self.MAP_SIZE,
|
||||
self.count_tile(map_data, self.DOOR) / self.MAP_SIZE,
|
||||
self.count_tile(map_data, self.MONSTER) / self.MAP_SIZE,
|
||||
self.count_tile(map_data, self.ENTRANCE) / self.MAP_SIZE,
|
||||
self.count_tile(map_data, self.RESOURCE) / self.MAP_SIZE
|
||||
])
|
||||
|
||||
def build_encoder_inputs(self, raw: np.ndarray) -> tuple:
|
||||
target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw.copy())
|
||||
enc1 = target1.copy()
|
||||
enc2 = inp2.copy()
|
||||
enc3 = raw.copy()
|
||||
return enc1, enc2, enc3
|
||||
|
||||
def pack_sample(self, item: dict, map_np: np.ndarray, out: tuple) -> dict:
|
||||
return {
|
||||
"input_stage1": torch.LongTensor(out[0]),
|
||||
"target_stage1": torch.LongTensor(out[1]),
|
||||
"encoder_stage1": torch.LongTensor(out[2]),
|
||||
"input_stage2": torch.LongTensor(out[3]),
|
||||
"target_stage2": torch.LongTensor(out[4]),
|
||||
"encoder_stage2": torch.LongTensor(out[5]),
|
||||
"input_stage3": torch.LongTensor(out[6]),
|
||||
"target_stage3": torch.LongTensor(out[7]),
|
||||
"encoder_stage3": torch.LongTensor(out[8]),
|
||||
"struct_inject": self.build_struct_inject(map_np, item['outerWall']),
|
||||
"target_density": self.build_target_density(item['map'])
|
||||
}
|
||||
|
||||
def random_sample_map(self, idx: int | None = None) -> dict:
|
||||
if idx is None:
|
||||
idx = random.randrange(len(self.data))
|
||||
|
||||
item = self.data[idx]
|
||||
map_np = np.array(item['map'], dtype=np.int64)
|
||||
|
||||
enc1, enc2, enc3 = self.build_encoder_inputs(map_np)
|
||||
sample = {
|
||||
"encoder_stage1": torch.LongTensor(enc1),
|
||||
"encoder_stage2": torch.LongTensor(enc2),
|
||||
"encoder_stage3": torch.LongTensor(enc3),
|
||||
"struct_inject": self.build_struct_inject(map_np, item['outerWall']),
|
||||
"target_density": self.build_target_density(item['map']),
|
||||
"raw_map": torch.LongTensor(map_np)
|
||||
}
|
||||
sample['sample_idx'] = idx
|
||||
return sample
|
||||
|
||||
def degrade_tile(self, m: np.ndarray, tiles: list) -> np.ndarray:
|
||||
# 将指定 tile ID 替换为 floor(0),原地修改
|
||||
for t in tiles:
|
||||
@ -179,30 +234,4 @@ class GinkaSeperatedDataset(Dataset):
|
||||
else:
|
||||
out = self.apply_subset3(map_np)
|
||||
|
||||
sym_h, sym_v, sym_c = compute_symmetry(map_np)
|
||||
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
|
||||
cond_outer = item['outerWall']
|
||||
struct_inject = torch.LongTensor([cond_sym, cond_outer])
|
||||
|
||||
m = item['map']
|
||||
target_density = torch.FloatTensor([
|
||||
self.count_tile(m, self.WALL) / self.MAP_SIZE,
|
||||
self.count_tile(m, self.DOOR) / self.MAP_SIZE,
|
||||
self.count_tile(m, self.MONSTER) / self.MAP_SIZE,
|
||||
self.count_tile(m, self.ENTRANCE) / self.MAP_SIZE,
|
||||
self.count_tile(m, self.RESOURCE) / self.MAP_SIZE,
|
||||
])
|
||||
|
||||
return {
|
||||
"input_stage1": torch.LongTensor(out[0]),
|
||||
"target_stage1": torch.LongTensor(out[1]),
|
||||
"encoder_stage1": torch.LongTensor(out[2]),
|
||||
"input_stage2": torch.LongTensor(out[3]),
|
||||
"target_stage2": torch.LongTensor(out[4]),
|
||||
"encoder_stage2": torch.LongTensor(out[5]),
|
||||
"input_stage3": torch.LongTensor(out[6]),
|
||||
"target_stage3": torch.LongTensor(out[7]),
|
||||
"encoder_stage3": torch.LongTensor(out[8]),
|
||||
"struct_inject": struct_inject,
|
||||
"target_density": target_density
|
||||
}
|
||||
return self.pack_sample(item, map_np, out)
|
||||
|
||||
@ -38,7 +38,7 @@ from shared.image import matrix_to_image_cv
|
||||
# 共用 VQ-VAE 超参
|
||||
# 三组编码器(vq1/vq2/vq3)共享相同超参,分别对三阶段地图上下文独立编码
|
||||
VQ_L = 16 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3)
|
||||
VQ_K = 16 # codebook 大小(离散码本条目数)
|
||||
VQ_K = 32 # codebook 大小(离散码本条目数)
|
||||
VQ_D_Z = 64 # 码字维度
|
||||
VQ_BETA = 1.0 # commit loss 权重(防止编码器输出漂离 codebook)
|
||||
VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用)
|
||||
@ -157,21 +157,22 @@ def build_model(device: torch.device):
|
||||
z_seq_len=VQ_L
|
||||
).to(device)
|
||||
|
||||
# 六个模型参数合并到同一优化器,端到端联合训练
|
||||
all_params = (
|
||||
list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) +
|
||||
list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters())
|
||||
)
|
||||
optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4)
|
||||
# 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
||||
|
||||
# 三个独立 VectorQuantizer:各阶段使用自己的码本,避免语义空间相互干扰
|
||||
quantizer1 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
quantizer2 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
quantizer3 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
quantizers = (quantizer1, quantizer2, quantizer3)
|
||||
|
||||
# 九个模块参数合并到同一优化器,端到端联合训练
|
||||
all_params = (
|
||||
list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) +
|
||||
list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters()) +
|
||||
list(quantizer1.parameters()) + list(quantizer2.parameters()) + list(quantizer3.parameters())
|
||||
)
|
||||
optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4)
|
||||
# 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
||||
|
||||
return vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler
|
||||
|
||||
def cross_entropy_loss(logits, target):
|
||||
@ -267,28 +268,6 @@ def sample_reference_inputs(
|
||||
|
||||
return sampled_reference
|
||||
|
||||
def random_struct(device: torch.device) -> torch.Tensor:
|
||||
# 随机采样一组结构参量,用于无条件自由生成
|
||||
# struct_inject 格式:[cond_sym(0-7), cond_outer(0-1)]
|
||||
cond_sym = random.randint(0, 7) # 地图对称类型
|
||||
cond_outer = random.randint(0, 1) # 是否有外围走廈
|
||||
return torch.LongTensor([cond_sym, cond_outer]).unsqueeze(0).to(device)
|
||||
|
||||
def random_target_density(density_stats: dict, device: torch.device) -> torch.Tensor:
|
||||
# 从训练集真实密度范围中采样 wall / door / monster / entrance / resource 目标密度
|
||||
wall_density = random.uniform(density_stats["wall_min_density"], density_stats["wall_max_density"])
|
||||
door_density = random.uniform(density_stats["door_min_density"], density_stats["door_max_density"])
|
||||
monster_density = random.uniform(density_stats["monster_min_density"], density_stats["monster_max_density"])
|
||||
entrance_density = random.uniform(density_stats["entrance_min_density"], density_stats["entrance_max_density"])
|
||||
resource_density = random.uniform(density_stats["resource_min_density"], density_stats["resource_max_density"])
|
||||
return torch.FloatTensor([
|
||||
wall_density,
|
||||
door_density,
|
||||
monster_density,
|
||||
entrance_density,
|
||||
resource_density,
|
||||
]).unsqueeze(0).to(device)
|
||||
|
||||
def compute_remaining(
|
||||
current: torch.Tensor,
|
||||
target_density: torch.Tensor,
|
||||
@ -416,50 +395,6 @@ def maskgit_sample(
|
||||
|
||||
return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
|
||||
|
||||
def full_generate_random_z(
|
||||
input: torch.Tensor,
|
||||
struct: torch.Tensor,
|
||||
target_density: torch.Tensor,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device,
|
||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||
) -> tuple:
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||
quantizer1, quantizer2, quantizer3 = quantizers
|
||||
|
||||
with torch.no_grad():
|
||||
z1 = quantizer1.sample(1, VQ_L, device)
|
||||
z2 = quantizer2.sample(1, VQ_L, device)
|
||||
z3 = quantizer3.sample(1, VQ_L, device)
|
||||
|
||||
# stage1:生成墙壁骨架
|
||||
pred1_np = maskgit_sample(
|
||||
mg1, input.clone(), z1, struct, target_density, 1,
|
||||
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
||||
)
|
||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充
|
||||
|
||||
# stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
|
||||
pred2_np = maskgit_sample(
|
||||
mg2, inp2, z2, struct, target_density, 2,
|
||||
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
)
|
||||
merged12 = pred1_np.copy()
|
||||
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
||||
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp3[inp3 == 0] = MASK_TOKEN
|
||||
|
||||
# stage3:填充 resource(3)
|
||||
pred3_np = maskgit_sample(
|
||||
mg3, inp3, z3, struct, target_density, 3,
|
||||
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
)
|
||||
merged123 = merged12.copy()
|
||||
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
||||
|
||||
return pred1_np, merged12, merged123
|
||||
|
||||
def full_generate_specific_z(
|
||||
input: torch.Tensor,
|
||||
z_q: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
@ -474,7 +409,7 @@ def full_generate_specific_z(
|
||||
with torch.no_grad():
|
||||
z1, z2, z3 = z_q
|
||||
|
||||
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
|
||||
# 三阶段级联生成,但使用给定的 z
|
||||
pred1_np = maskgit_sample(
|
||||
mg1, input.clone(), z1, struct, target_density, 1,
|
||||
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
||||
@ -500,11 +435,11 @@ def full_generate_specific_z(
|
||||
|
||||
return pred1_np, merged12, merged123
|
||||
|
||||
def annotate(img: np.ndarray, text: str) -> np.ndarray:
|
||||
def annotate(img: np.ndarray, text: str, y: int = 14) -> np.ndarray:
|
||||
# 在图片左上角叠加文字标注(黑色描边 + 白色填充,确保任意背景下可读)
|
||||
img = img.copy()
|
||||
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2)
|
||||
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
|
||||
cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2)
|
||||
cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
|
||||
return img
|
||||
|
||||
def annotate_labels(
|
||||
@ -531,6 +466,40 @@ def rand_keep() -> tuple[bool, bool, bool]:
|
||||
def keep_label(kf: tuple[bool, bool, bool]) -> str:
|
||||
return 'fix' if kf[0] else 'free'
|
||||
|
||||
def build_dataset_sample_case(
|
||||
dataset: GinkaSeperatedDataset,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device,
|
||||
idx: int | None = None
|
||||
) -> dict:
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||
sample = dataset.random_sample_map(idx=idx)
|
||||
|
||||
enc1_t = sample["encoder_stage1"].to(device).reshape(1, MAP_SIZE)
|
||||
enc2_t = sample["encoder_stage2"].to(device).reshape(1, MAP_SIZE)
|
||||
enc3_t = sample["encoder_stage3"].to(device).reshape(1, MAP_SIZE)
|
||||
struct_t = sample["struct_inject"].to(device).reshape(1, -1)
|
||||
target_density_t = sample["target_density"].to(device).reshape(1, -1)
|
||||
|
||||
with torch.no_grad():
|
||||
z_e1 = vq1(enc1_t)
|
||||
z_e2 = vq2(enc2_t)
|
||||
z_e3 = vq3(enc3_t)
|
||||
z_q, commit_loss, code_hits = quantize_stage_latents(
|
||||
quantizers, z_e1, z_e2, z_e3
|
||||
)
|
||||
|
||||
return {
|
||||
"sample": sample,
|
||||
"struct": struct_t,
|
||||
"target_density": target_density_t,
|
||||
"z_q": z_q,
|
||||
"sample_idx": sample["sample_idx"]
|
||||
}
|
||||
|
||||
def sample_case_label(case: dict) -> str:
|
||||
return f"train#{case['sample_idx']}"
|
||||
|
||||
# 验证可视化 part1:3×3 网格;行1=编码器输入,行2=掩码输入,行3=三阶段预测(合并)
|
||||
def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
|
||||
SEP = 3
|
||||
@ -618,52 +587,13 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
||||
grid[y:y + img_h, x:x + img_w] = img
|
||||
return grid
|
||||
|
||||
# 验证可视化 part3:2×3 网格;行1=参考输入+相同 struct 随机 z 生成,行2=随机 struct 生成
|
||||
def visualize_part3(batch, models, device, tile_dict, density_stats: dict):
|
||||
SEP = 3
|
||||
TILE_SIZE = 32
|
||||
img_h = MAP_H * TILE_SIZE
|
||||
img_w = MAP_W * TILE_SIZE
|
||||
|
||||
def to_img(mat):
|
||||
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
|
||||
|
||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||
struct_ref = batch["struct_inject"][0:1].to(device)
|
||||
target_density_ref = batch["target_density"][0:1].to(device)
|
||||
inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
struct_cpu = batch["struct_inject"][0]
|
||||
target_density_cpu = batch["target_density"][0]
|
||||
|
||||
row1 = [to_img(inp1_np)]
|
||||
for _ in range(2):
|
||||
kf = rand_keep()
|
||||
_, _, merged123 = full_generate_random_z(
|
||||
inp1_t, struct_ref, target_density_ref, models, device, keep_fixed=kf
|
||||
)
|
||||
row1.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
||||
|
||||
row2 = []
|
||||
for _ in range(3):
|
||||
kf = rand_keep()
|
||||
rnd_struct = random_struct(device)
|
||||
rnd_target_density = random_target_density(density_stats, device)
|
||||
_, _, merged123 = full_generate_random_z(
|
||||
inp1_t, rnd_struct, rnd_target_density, models, device, keep_fixed=kf
|
||||
)
|
||||
row2.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_target_density[0].cpu()))
|
||||
|
||||
rows = [row1, row2]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||
for r, row in enumerate(rows):
|
||||
for c, img in enumerate(row):
|
||||
y = SEP + r * (img_h + SEP)
|
||||
x = SEP + c * (img_w + SEP)
|
||||
grid[y:y + img_h, x:x + img_w] = img
|
||||
return grid
|
||||
|
||||
# 验证可视化 part4:2×3 网格;以少量随机墙壁作为种子,纯随机 struct+z 自由生成
|
||||
def visualize_part4(models, device, tile_dict, density_stats: dict):
|
||||
# 验证可视化 part4:2×3 网格;保留稀疏墙壁种子,但 z 与标签来自训练集样本
|
||||
def visualize_part4(
|
||||
train_dataset: GinkaSeperatedDataset,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device,
|
||||
tile_dict
|
||||
):
|
||||
SEP = 3
|
||||
TILE_SIZE = 32
|
||||
img_h = MAP_H * TILE_SIZE
|
||||
@ -680,15 +610,21 @@ def visualize_part4(models, device, tile_dict, density_stats: dict):
|
||||
|
||||
results = []
|
||||
for _ in range(5):
|
||||
case = build_dataset_sample_case(train_dataset, models, device)
|
||||
kf = rand_keep()
|
||||
rnd_struct = random_struct(device)
|
||||
rnd_target_density = random_target_density(density_stats, device)
|
||||
_, _, merged123 = full_generate_random_z(
|
||||
seed, rnd_struct, rnd_target_density, models, device, keep_fixed=kf
|
||||
sample = case["sample"]
|
||||
_, _, merged123 = full_generate_specific_z(
|
||||
seed, case["z_q"], case["struct"], case["target_density"],
|
||||
models, device, keep_fixed=kf
|
||||
)
|
||||
result = annotate_labels(
|
||||
to_img(merged123), sample["struct_inject"], sample["target_density"]
|
||||
)
|
||||
results.append(
|
||||
annotate(result, f"{sample_case_label(case)} {keep_label(kf)}", y=50)
|
||||
)
|
||||
results.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_target_density[0].cpu()))
|
||||
|
||||
row1 = [to_img(seed_np)] + results[:2]
|
||||
row1 = [annotate(to_img(seed_np), 'seed')] + results[:2]
|
||||
row2 = results[2:]
|
||||
rows = [row1, row2]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||
@ -702,92 +638,20 @@ def visualize_part4(models, device, tile_dict, density_stats: dict):
|
||||
def visualize_validate(
|
||||
batch, logits1, logits2, logits3, z_q,
|
||||
models: list[torch.nn.Module], device: torch.device, tile_dict,
|
||||
density_stats: dict, epoch: int, batch_idx: int
|
||||
train_dataset: GinkaSeperatedDataset, epoch: int, batch_idx: int
|
||||
):
|
||||
save_dir = f"result/seperated/e{epoch}"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
cv2.imwrite(f"{save_dir}/val{batch_idx}.png", visualize_part1(batch, logits1, logits2, logits3, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/full{batch_idx}.png", visualize_part2(batch, z_q, models, device, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part3(batch, models, device, tile_dict, density_stats))
|
||||
cv2.imwrite(f"{save_dir}/dvar{batch_idx}.png", visualize_density_var(batch, z_q, models, device, tile_dict))
|
||||
|
||||
# 密度对照图:随机种子+随机结构,5 张随机密度生成,2×3 网格(左上角为种子图)
|
||||
def visualize_density_cmp(models, device, tile_dict, density_stats: dict):
|
||||
SEP = 3
|
||||
TILE_SIZE = 32
|
||||
img_h = MAP_H * TILE_SIZE
|
||||
img_w = MAP_W * TILE_SIZE
|
||||
|
||||
def to_img(mat):
|
||||
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
|
||||
|
||||
n_walls = random.randint(math.floor(MAP_SIZE * 0.02), math.floor(MAP_SIZE * 0.06))
|
||||
seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
|
||||
wall_pos = torch.randperm(MAP_SIZE, device=device)[:n_walls]
|
||||
seed[0, wall_pos] = 1
|
||||
seed_np = seed[0].cpu().numpy().reshape(MAP_H, MAP_W)
|
||||
rnd_struct = random_struct(device)
|
||||
struct_cpu = rnd_struct[0].cpu()
|
||||
gen_imgs = []
|
||||
for _ in range(5):
|
||||
rnd_target_density = random_target_density(density_stats, device)
|
||||
target_density_cpu = rnd_target_density[0].cpu()
|
||||
_, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_target_density, models, device)
|
||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
||||
row1 = [to_img(seed_np)] + gen_imgs[:2]
|
||||
row2 = gen_imgs[2:]
|
||||
rows = [row1, row2]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||
for r, row in enumerate(rows):
|
||||
for c, img in enumerate(row):
|
||||
y = SEP + r * (img_h + SEP)
|
||||
x = SEP + c * (img_w + SEP)
|
||||
grid[y:y + img_h, x:x + img_w] = img
|
||||
return grid
|
||||
|
||||
# 固定 z 和结构条件,扫描 5 个不同墙壁目标密度,2×3 网格(左上角为参考地图)
|
||||
def visualize_density_var(batch, z_q, models, device, tile_dict):
|
||||
SEP = 3
|
||||
TILE_SIZE = 32
|
||||
img_h = MAP_H * TILE_SIZE
|
||||
img_w = MAP_W * TILE_SIZE
|
||||
|
||||
def to_img(mat):
|
||||
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
|
||||
|
||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||
struct_t = batch["struct_inject"][0:1].to(device)
|
||||
struct_cpu = batch["struct_inject"][0]
|
||||
base_target_density = batch["target_density"][0:1].to(device)
|
||||
z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1])
|
||||
ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
gen_imgs = []
|
||||
wall_count_values = [20, 35, 50, 65, 80]
|
||||
for wall_count in wall_count_values:
|
||||
fixed_target_density = base_target_density.clone()
|
||||
fixed_target_density[0, WALL_DENSITY_IDX] = wall_count / MAP_SIZE
|
||||
target_density_cpu = fixed_target_density[0].cpu()
|
||||
_, _, merged123 = full_generate_specific_z(
|
||||
inp1_t, z_q_single, struct_t, fixed_target_density, models, device
|
||||
)
|
||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
||||
row1 = [to_img(ref_np)] + gen_imgs[:2]
|
||||
row2 = gen_imgs[2:]
|
||||
rows = [row1, row2]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||
for r, row in enumerate(rows):
|
||||
for c, img in enumerate(row):
|
||||
y = SEP + r * (img_h + SEP)
|
||||
x = SEP + c * (img_w + SEP)
|
||||
grid[y:y + img_h, x:x + img_w] = img
|
||||
return grid
|
||||
cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part4(train_dataset, models, device, tile_dict))
|
||||
|
||||
def validate(
|
||||
dataloader: DataLoader,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device,
|
||||
tile_dict,
|
||||
density_stats: dict,
|
||||
train_dataset: GinkaSeperatedDataset,
|
||||
epoch: int
|
||||
):
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||
@ -885,7 +749,7 @@ def validate(
|
||||
# 每个 batch 生成三种可视化图(val/full/rand)
|
||||
visualize_validate(
|
||||
batch, logits1, logits2, logits3, z_q,
|
||||
models, device, tile_dict, density_stats, epoch, idx
|
||||
models, device, tile_dict, train_dataset, epoch, idx
|
||||
)
|
||||
idx += 1
|
||||
|
||||
@ -896,12 +760,6 @@ def validate(
|
||||
avg_over = density_metrics[tile_id]["over"] / count if count > 0 else 0.0
|
||||
tqdm.write(f" density {tile_names[tile_id]}: mae={avg_mae:.4f} over={avg_over:.4f}")
|
||||
|
||||
save_dir = f"result/seperated/e{epoch}"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# 每个 epoch 额外生成:无条件自由生成图 + 全局密度对照图
|
||||
cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict, density_stats))
|
||||
cv2.imwrite(f"{save_dir}/density_cmp.png", visualize_density_cmp(models, device, tile_dict, density_stats))
|
||||
|
||||
# 恢复训练模式
|
||||
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
||||
m.train()
|
||||
@ -939,15 +797,9 @@ def train(device: torch.device):
|
||||
mg1.load_state_dict(ckpt["mg1"])
|
||||
mg2.load_state_dict(ckpt["mg2"])
|
||||
mg3.load_state_dict(ckpt["mg3"])
|
||||
if "quantizer1" in ckpt:
|
||||
quantizer1.load_state_dict(ckpt["quantizer1"])
|
||||
quantizer2.load_state_dict(ckpt["quantizer2"])
|
||||
quantizer3.load_state_dict(ckpt["quantizer3"])
|
||||
elif "quantizer" in ckpt:
|
||||
quantizer1.load_state_dict(ckpt["quantizer"])
|
||||
quantizer2.load_state_dict(ckpt["quantizer"])
|
||||
quantizer3.load_state_dict(ckpt["quantizer"])
|
||||
tqdm.write("Loaded legacy shared quantizer weights into quantizer1/2/3")
|
||||
quantizer1.load_state_dict(ckpt["quantizer1"])
|
||||
quantizer2.load_state_dict(ckpt["quantizer2"])
|
||||
quantizer3.load_state_dict(ckpt["quantizer3"])
|
||||
# load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练)
|
||||
if args.load_optim and "optimizer" in ckpt:
|
||||
optimizer.load_state_dict(ckpt["optimizer"])
|
||||
@ -1080,7 +932,9 @@ def train(device: torch.device):
|
||||
|
||||
# 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存
|
||||
if (epoch + 1) % CHECKPOINT == 0:
|
||||
losses = validate(dataloader_val, models, device, tile_dict, dataset.density_stats, epoch + 1)
|
||||
losses = validate(
|
||||
dataloader_val, models, device, tile_dict, dataset, epoch + 1
|
||||
)
|
||||
loss1_total, loss2_total, loss3_total, commit_total, code_hits_total = losses
|
||||
loss1_weighted = STAGE1_CE_WEIGHT * loss1_total
|
||||
loss2_weighted = STAGE2_CE_WEIGHT * loss2_total
|
||||
|
||||
@ -4,19 +4,29 @@ import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
class VectorQuantizer(nn.Module):
|
||||
def __init__(self, K: int, d_z: int):
|
||||
"""
|
||||
Args:
|
||||
K: codebook 大小(码字数量)
|
||||
d_z: 码字嵌入维度
|
||||
temp: 软分配 softmax 温度,越小越接近 hard assignment
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
K: int,
|
||||
d_z: int,
|
||||
decay: float = 0.99,
|
||||
epsilon: float = 1e-5
|
||||
):
|
||||
super().__init__()
|
||||
self.K = K
|
||||
self.d_z = d_z
|
||||
self.decay = decay
|
||||
self.epsilon = epsilon
|
||||
|
||||
self.codebook = nn.Embedding(K, d_z)
|
||||
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
|
||||
self.codebook.weight.requires_grad_(False)
|
||||
|
||||
# EMA 统计量:码字访问次数与对应编码向量和。
|
||||
self.register_buffer("ema_cluster_size", torch.ones(K))
|
||||
self.register_buffer(
|
||||
"ema_weight",
|
||||
self.codebook.weight.detach().clone()
|
||||
)
|
||||
|
||||
def codebook_stats(
|
||||
self, indices: torch.Tensor
|
||||
@ -31,19 +41,32 @@ class VectorQuantizer(nn.Module):
|
||||
usage_count = one_hot.sum(dim=0)
|
||||
return perplexity, usage_rate, usage_count
|
||||
|
||||
def ema_update(self, z_flat: torch.Tensor, flat_indices: torch.Tensor):
|
||||
one_hot = F.one_hot(flat_indices, num_classes=self.K).type_as(z_flat)
|
||||
cluster_size = one_hot.sum(dim=0)
|
||||
embed_sum = one_hot.transpose(0, 1) @ z_flat
|
||||
|
||||
self.ema_cluster_size.mul_(self.decay).add_(
|
||||
cluster_size,
|
||||
alpha=1.0 - self.decay
|
||||
)
|
||||
self.ema_weight.mul_(self.decay).add_(
|
||||
embed_sum,
|
||||
alpha=1.0 - self.decay
|
||||
)
|
||||
|
||||
total_count = self.ema_cluster_size.sum()
|
||||
normalized_cluster_size = (
|
||||
(self.ema_cluster_size + self.epsilon) /
|
||||
(total_count + self.K * self.epsilon) * total_count
|
||||
)
|
||||
normalized_weight = self.ema_weight / normalized_cluster_size.unsqueeze(1)
|
||||
self.codebook.weight.data.copy_(normalized_weight)
|
||||
|
||||
def forward(
|
||||
self, z_e: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# z_e: [B, L, d_z]
|
||||
"""
|
||||
Args:
|
||||
z_e: [B, L, d_z] 编码器输出的连续向量序列
|
||||
|
||||
Returns:
|
||||
z_q_st: [B, L, d_z] 量化后向量(直通梯度)
|
||||
indices: [B, L] 每个位置对应的码字索引
|
||||
commit_loss: scalar 承诺损失 ||z_e - sg(z_q)||^2
|
||||
"""
|
||||
B, L, d_z = z_e.shape
|
||||
|
||||
z_flat = z_e.reshape(B * L, d_z) # [B * L, d_z]
|
||||
@ -58,10 +81,10 @@ class VectorQuantizer(nn.Module):
|
||||
distances = ze_square + ek_square - 2 * mul
|
||||
|
||||
# Hard assignment:取最近码字索引
|
||||
indices = distances.argmin(dim=1) # [B*L]
|
||||
flat_indices = distances.argmin(dim=1) # [B*L]
|
||||
|
||||
# 量化向量
|
||||
z_q_flat = self.codebook(indices) # [B*L, d_z]
|
||||
z_q_flat = self.codebook(flat_indices) # [B*L, d_z]
|
||||
z_q = z_q_flat.reshape(B, L, d_z)
|
||||
|
||||
# 直通估计:前向传 z_q,反向传 z_e 的梯度
|
||||
@ -70,7 +93,11 @@ class VectorQuantizer(nn.Module):
|
||||
# 承诺损失:拉近编码向量与其对应的码字(仅更新编码器)
|
||||
commit_loss = F.mse_loss(z_e, z_q.detach())
|
||||
|
||||
indices = indices.reshape(B, L)
|
||||
# 训练时使用 EMA 更新码本;验证与推理阶段保持码本固定。
|
||||
if self.training and z_e.requires_grad:
|
||||
self.ema_update(z_flat.detach(), flat_indices.detach())
|
||||
|
||||
indices = flat_indices.reshape(B, L)
|
||||
perplexity, usage_rate, usage_count = self.codebook_stats(indices)
|
||||
return z_q_st, indices, commit_loss, perplexity, usage_count
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user