mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 18:31:13 +08:00
feat: 从训练集随机抽样生成,而不是完全随机采样
This commit is contained in:
parent
416aa4dd72
commit
14ee52fb2f
@ -71,6 +71,61 @@ class GinkaSeperatedDataset(Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
|
def build_struct_inject(self, map_np: np.ndarray, outer_wall: int) -> torch.Tensor:
|
||||||
|
sym_h, sym_v, sym_c = compute_symmetry(map_np)
|
||||||
|
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
|
||||||
|
return torch.LongTensor([cond_sym, outer_wall])
|
||||||
|
|
||||||
|
def build_target_density(self, map_data: list) -> torch.Tensor:
|
||||||
|
return torch.FloatTensor([
|
||||||
|
self.count_tile(map_data, self.WALL) / self.MAP_SIZE,
|
||||||
|
self.count_tile(map_data, self.DOOR) / self.MAP_SIZE,
|
||||||
|
self.count_tile(map_data, self.MONSTER) / self.MAP_SIZE,
|
||||||
|
self.count_tile(map_data, self.ENTRANCE) / self.MAP_SIZE,
|
||||||
|
self.count_tile(map_data, self.RESOURCE) / self.MAP_SIZE
|
||||||
|
])
|
||||||
|
|
||||||
|
def build_encoder_inputs(self, raw: np.ndarray) -> tuple:
|
||||||
|
target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw.copy())
|
||||||
|
enc1 = target1.copy()
|
||||||
|
enc2 = inp2.copy()
|
||||||
|
enc3 = raw.copy()
|
||||||
|
return enc1, enc2, enc3
|
||||||
|
|
||||||
|
def pack_sample(self, item: dict, map_np: np.ndarray, out: tuple) -> dict:
|
||||||
|
return {
|
||||||
|
"input_stage1": torch.LongTensor(out[0]),
|
||||||
|
"target_stage1": torch.LongTensor(out[1]),
|
||||||
|
"encoder_stage1": torch.LongTensor(out[2]),
|
||||||
|
"input_stage2": torch.LongTensor(out[3]),
|
||||||
|
"target_stage2": torch.LongTensor(out[4]),
|
||||||
|
"encoder_stage2": torch.LongTensor(out[5]),
|
||||||
|
"input_stage3": torch.LongTensor(out[6]),
|
||||||
|
"target_stage3": torch.LongTensor(out[7]),
|
||||||
|
"encoder_stage3": torch.LongTensor(out[8]),
|
||||||
|
"struct_inject": self.build_struct_inject(map_np, item['outerWall']),
|
||||||
|
"target_density": self.build_target_density(item['map'])
|
||||||
|
}
|
||||||
|
|
||||||
|
def random_sample_map(self, idx: int | None = None) -> dict:
|
||||||
|
if idx is None:
|
||||||
|
idx = random.randrange(len(self.data))
|
||||||
|
|
||||||
|
item = self.data[idx]
|
||||||
|
map_np = np.array(item['map'], dtype=np.int64)
|
||||||
|
|
||||||
|
enc1, enc2, enc3 = self.build_encoder_inputs(map_np)
|
||||||
|
sample = {
|
||||||
|
"encoder_stage1": torch.LongTensor(enc1),
|
||||||
|
"encoder_stage2": torch.LongTensor(enc2),
|
||||||
|
"encoder_stage3": torch.LongTensor(enc3),
|
||||||
|
"struct_inject": self.build_struct_inject(map_np, item['outerWall']),
|
||||||
|
"target_density": self.build_target_density(item['map']),
|
||||||
|
"raw_map": torch.LongTensor(map_np)
|
||||||
|
}
|
||||||
|
sample['sample_idx'] = idx
|
||||||
|
return sample
|
||||||
|
|
||||||
def degrade_tile(self, m: np.ndarray, tiles: list) -> np.ndarray:
|
def degrade_tile(self, m: np.ndarray, tiles: list) -> np.ndarray:
|
||||||
# 将指定 tile ID 替换为 floor(0),原地修改
|
# 将指定 tile ID 替换为 floor(0),原地修改
|
||||||
for t in tiles:
|
for t in tiles:
|
||||||
@ -179,30 +234,4 @@ class GinkaSeperatedDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
out = self.apply_subset3(map_np)
|
out = self.apply_subset3(map_np)
|
||||||
|
|
||||||
sym_h, sym_v, sym_c = compute_symmetry(map_np)
|
return self.pack_sample(item, map_np, out)
|
||||||
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
|
|
||||||
cond_outer = item['outerWall']
|
|
||||||
struct_inject = torch.LongTensor([cond_sym, cond_outer])
|
|
||||||
|
|
||||||
m = item['map']
|
|
||||||
target_density = torch.FloatTensor([
|
|
||||||
self.count_tile(m, self.WALL) / self.MAP_SIZE,
|
|
||||||
self.count_tile(m, self.DOOR) / self.MAP_SIZE,
|
|
||||||
self.count_tile(m, self.MONSTER) / self.MAP_SIZE,
|
|
||||||
self.count_tile(m, self.ENTRANCE) / self.MAP_SIZE,
|
|
||||||
self.count_tile(m, self.RESOURCE) / self.MAP_SIZE,
|
|
||||||
])
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_stage1": torch.LongTensor(out[0]),
|
|
||||||
"target_stage1": torch.LongTensor(out[1]),
|
|
||||||
"encoder_stage1": torch.LongTensor(out[2]),
|
|
||||||
"input_stage2": torch.LongTensor(out[3]),
|
|
||||||
"target_stage2": torch.LongTensor(out[4]),
|
|
||||||
"encoder_stage2": torch.LongTensor(out[5]),
|
|
||||||
"input_stage3": torch.LongTensor(out[6]),
|
|
||||||
"target_stage3": torch.LongTensor(out[7]),
|
|
||||||
"encoder_stage3": torch.LongTensor(out[8]),
|
|
||||||
"struct_inject": struct_inject,
|
|
||||||
"target_density": target_density
|
|
||||||
}
|
|
||||||
|
|||||||
@ -38,7 +38,7 @@ from shared.image import matrix_to_image_cv
|
|||||||
# 共用 VQ-VAE 超参
|
# 共用 VQ-VAE 超参
|
||||||
# 三组编码器(vq1/vq2/vq3)共享相同超参,分别对三阶段地图上下文独立编码
|
# 三组编码器(vq1/vq2/vq3)共享相同超参,分别对三阶段地图上下文独立编码
|
||||||
VQ_L = 16 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3)
|
VQ_L = 16 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3)
|
||||||
VQ_K = 16 # codebook 大小(离散码本条目数)
|
VQ_K = 32 # codebook 大小(离散码本条目数)
|
||||||
VQ_D_Z = 64 # 码字维度
|
VQ_D_Z = 64 # 码字维度
|
||||||
VQ_BETA = 1.0 # commit loss 权重(防止编码器输出漂离 codebook)
|
VQ_BETA = 1.0 # commit loss 权重(防止编码器输出漂离 codebook)
|
||||||
VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用)
|
VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用)
|
||||||
@ -267,28 +267,6 @@ def sample_reference_inputs(
|
|||||||
|
|
||||||
return sampled_reference
|
return sampled_reference
|
||||||
|
|
||||||
def random_struct(device: torch.device) -> torch.Tensor:
|
|
||||||
# 随机采样一组结构参量,用于无条件自由生成
|
|
||||||
# struct_inject 格式:[cond_sym(0-7), cond_outer(0-1)]
|
|
||||||
cond_sym = random.randint(0, 7) # 地图对称类型
|
|
||||||
cond_outer = random.randint(0, 1) # 是否有外围走廈
|
|
||||||
return torch.LongTensor([cond_sym, cond_outer]).unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
def random_target_density(density_stats: dict, device: torch.device) -> torch.Tensor:
|
|
||||||
# 从训练集真实密度范围中采样 wall / door / monster / entrance / resource 目标密度
|
|
||||||
wall_density = random.uniform(density_stats["wall_min_density"], density_stats["wall_max_density"])
|
|
||||||
door_density = random.uniform(density_stats["door_min_density"], density_stats["door_max_density"])
|
|
||||||
monster_density = random.uniform(density_stats["monster_min_density"], density_stats["monster_max_density"])
|
|
||||||
entrance_density = random.uniform(density_stats["entrance_min_density"], density_stats["entrance_max_density"])
|
|
||||||
resource_density = random.uniform(density_stats["resource_min_density"], density_stats["resource_max_density"])
|
|
||||||
return torch.FloatTensor([
|
|
||||||
wall_density,
|
|
||||||
door_density,
|
|
||||||
monster_density,
|
|
||||||
entrance_density,
|
|
||||||
resource_density,
|
|
||||||
]).unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
def compute_remaining(
|
def compute_remaining(
|
||||||
current: torch.Tensor,
|
current: torch.Tensor,
|
||||||
target_density: torch.Tensor,
|
target_density: torch.Tensor,
|
||||||
@ -416,50 +394,6 @@ def maskgit_sample(
|
|||||||
|
|
||||||
return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
|
return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
|
||||||
|
|
||||||
def full_generate_random_z(
|
|
||||||
input: torch.Tensor,
|
|
||||||
struct: torch.Tensor,
|
|
||||||
target_density: torch.Tensor,
|
|
||||||
models: list[torch.nn.Module],
|
|
||||||
device: torch.device,
|
|
||||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
|
||||||
) -> tuple:
|
|
||||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
|
||||||
quantizer1, quantizer2, quantizer3 = quantizers
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
z1 = quantizer1.sample(1, VQ_L, device)
|
|
||||||
z2 = quantizer2.sample(1, VQ_L, device)
|
|
||||||
z3 = quantizer3.sample(1, VQ_L, device)
|
|
||||||
|
|
||||||
# stage1:生成墙壁骨架
|
|
||||||
pred1_np = maskgit_sample(
|
|
||||||
mg1, input.clone(), z1, struct, target_density, 1,
|
|
||||||
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
|
||||||
)
|
|
||||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
|
||||||
inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充
|
|
||||||
|
|
||||||
# stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
|
|
||||||
pred2_np = maskgit_sample(
|
|
||||||
mg2, inp2, z2, struct, target_density, 2,
|
|
||||||
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
|
||||||
)
|
|
||||||
merged12 = pred1_np.copy()
|
|
||||||
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
|
||||||
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
|
||||||
inp3[inp3 == 0] = MASK_TOKEN
|
|
||||||
|
|
||||||
# stage3:填充 resource(3)
|
|
||||||
pred3_np = maskgit_sample(
|
|
||||||
mg3, inp3, z3, struct, target_density, 3,
|
|
||||||
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
|
||||||
)
|
|
||||||
merged123 = merged12.copy()
|
|
||||||
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
|
||||||
|
|
||||||
return pred1_np, merged12, merged123
|
|
||||||
|
|
||||||
def full_generate_specific_z(
|
def full_generate_specific_z(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
z_q: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
z_q: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||||
@ -474,7 +408,7 @@ def full_generate_specific_z(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z1, z2, z3 = z_q
|
z1, z2, z3 = z_q
|
||||||
|
|
||||||
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
|
# 三阶段级联生成,但使用给定的 z
|
||||||
pred1_np = maskgit_sample(
|
pred1_np = maskgit_sample(
|
||||||
mg1, input.clone(), z1, struct, target_density, 1,
|
mg1, input.clone(), z1, struct, target_density, 1,
|
||||||
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
||||||
@ -500,11 +434,11 @@ def full_generate_specific_z(
|
|||||||
|
|
||||||
return pred1_np, merged12, merged123
|
return pred1_np, merged12, merged123
|
||||||
|
|
||||||
def annotate(img: np.ndarray, text: str) -> np.ndarray:
|
def annotate(img: np.ndarray, text: str, y: int = 14) -> np.ndarray:
|
||||||
# 在图片左上角叠加文字标注(黑色描边 + 白色填充,确保任意背景下可读)
|
# 在图片左上角叠加文字标注(黑色描边 + 白色填充,确保任意背景下可读)
|
||||||
img = img.copy()
|
img = img.copy()
|
||||||
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2)
|
cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2)
|
||||||
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
|
cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def annotate_labels(
|
def annotate_labels(
|
||||||
@ -531,6 +465,40 @@ def rand_keep() -> tuple[bool, bool, bool]:
|
|||||||
def keep_label(kf: tuple[bool, bool, bool]) -> str:
|
def keep_label(kf: tuple[bool, bool, bool]) -> str:
|
||||||
return 'fix' if kf[0] else 'free'
|
return 'fix' if kf[0] else 'free'
|
||||||
|
|
||||||
|
def build_dataset_sample_case(
|
||||||
|
dataset: GinkaSeperatedDataset,
|
||||||
|
models: list[torch.nn.Module],
|
||||||
|
device: torch.device,
|
||||||
|
idx: int | None = None
|
||||||
|
) -> dict:
|
||||||
|
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||||
|
sample = dataset.random_sample_map(idx=idx)
|
||||||
|
|
||||||
|
enc1_t = sample["encoder_stage1"].to(device).reshape(1, MAP_SIZE)
|
||||||
|
enc2_t = sample["encoder_stage2"].to(device).reshape(1, MAP_SIZE)
|
||||||
|
enc3_t = sample["encoder_stage3"].to(device).reshape(1, MAP_SIZE)
|
||||||
|
struct_t = sample["struct_inject"].to(device).reshape(1, -1)
|
||||||
|
target_density_t = sample["target_density"].to(device).reshape(1, -1)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
z_e1 = vq1(enc1_t)
|
||||||
|
z_e2 = vq2(enc2_t)
|
||||||
|
z_e3 = vq3(enc3_t)
|
||||||
|
z_q, commit_loss, code_hits = quantize_stage_latents(
|
||||||
|
quantizers, z_e1, z_e2, z_e3
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sample": sample,
|
||||||
|
"struct": struct_t,
|
||||||
|
"target_density": target_density_t,
|
||||||
|
"z_q": z_q,
|
||||||
|
"sample_idx": sample["sample_idx"]
|
||||||
|
}
|
||||||
|
|
||||||
|
def sample_case_label(case: dict) -> str:
|
||||||
|
return f"train#{case['sample_idx']}"
|
||||||
|
|
||||||
# 验证可视化 part1:3×3 网格;行1=编码器输入,行2=掩码输入,行3=三阶段预测(合并)
|
# 验证可视化 part1:3×3 网格;行1=编码器输入,行2=掩码输入,行3=三阶段预测(合并)
|
||||||
def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
|
def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
|
||||||
SEP = 3
|
SEP = 3
|
||||||
@ -618,52 +586,13 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
|||||||
grid[y:y + img_h, x:x + img_w] = img
|
grid[y:y + img_h, x:x + img_w] = img
|
||||||
return grid
|
return grid
|
||||||
|
|
||||||
# 验证可视化 part3:2×3 网格;行1=参考输入+相同 struct 随机 z 生成,行2=随机 struct 生成
|
# 验证可视化 part4:2×3 网格;保留稀疏墙壁种子,但 z 与标签来自训练集样本
|
||||||
def visualize_part3(batch, models, device, tile_dict, density_stats: dict):
|
def visualize_part4(
|
||||||
SEP = 3
|
train_dataset: GinkaSeperatedDataset,
|
||||||
TILE_SIZE = 32
|
models: list[torch.nn.Module],
|
||||||
img_h = MAP_H * TILE_SIZE
|
device: torch.device,
|
||||||
img_w = MAP_W * TILE_SIZE
|
tile_dict
|
||||||
|
):
|
||||||
def to_img(mat):
|
|
||||||
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
|
|
||||||
|
|
||||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
|
||||||
struct_ref = batch["struct_inject"][0:1].to(device)
|
|
||||||
target_density_ref = batch["target_density"][0:1].to(device)
|
|
||||||
inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
|
||||||
struct_cpu = batch["struct_inject"][0]
|
|
||||||
target_density_cpu = batch["target_density"][0]
|
|
||||||
|
|
||||||
row1 = [to_img(inp1_np)]
|
|
||||||
for _ in range(2):
|
|
||||||
kf = rand_keep()
|
|
||||||
_, _, merged123 = full_generate_random_z(
|
|
||||||
inp1_t, struct_ref, target_density_ref, models, device, keep_fixed=kf
|
|
||||||
)
|
|
||||||
row1.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
|
||||||
|
|
||||||
row2 = []
|
|
||||||
for _ in range(3):
|
|
||||||
kf = rand_keep()
|
|
||||||
rnd_struct = random_struct(device)
|
|
||||||
rnd_target_density = random_target_density(density_stats, device)
|
|
||||||
_, _, merged123 = full_generate_random_z(
|
|
||||||
inp1_t, rnd_struct, rnd_target_density, models, device, keep_fixed=kf
|
|
||||||
)
|
|
||||||
row2.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_target_density[0].cpu()))
|
|
||||||
|
|
||||||
rows = [row1, row2]
|
|
||||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
|
||||||
for r, row in enumerate(rows):
|
|
||||||
for c, img in enumerate(row):
|
|
||||||
y = SEP + r * (img_h + SEP)
|
|
||||||
x = SEP + c * (img_w + SEP)
|
|
||||||
grid[y:y + img_h, x:x + img_w] = img
|
|
||||||
return grid
|
|
||||||
|
|
||||||
# 验证可视化 part4:2×3 网格;以少量随机墙壁作为种子,纯随机 struct+z 自由生成
|
|
||||||
def visualize_part4(models, device, tile_dict, density_stats: dict):
|
|
||||||
SEP = 3
|
SEP = 3
|
||||||
TILE_SIZE = 32
|
TILE_SIZE = 32
|
||||||
img_h = MAP_H * TILE_SIZE
|
img_h = MAP_H * TILE_SIZE
|
||||||
@ -680,15 +609,21 @@ def visualize_part4(models, device, tile_dict, density_stats: dict):
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
|
case = build_dataset_sample_case(train_dataset, models, device)
|
||||||
kf = rand_keep()
|
kf = rand_keep()
|
||||||
rnd_struct = random_struct(device)
|
sample = case["sample"]
|
||||||
rnd_target_density = random_target_density(density_stats, device)
|
_, _, merged123 = full_generate_specific_z(
|
||||||
_, _, merged123 = full_generate_random_z(
|
seed, case["z_q"], case["struct"], case["target_density"],
|
||||||
seed, rnd_struct, rnd_target_density, models, device, keep_fixed=kf
|
models, device, keep_fixed=kf
|
||||||
|
)
|
||||||
|
result = annotate_labels(
|
||||||
|
to_img(merged123), sample["struct_inject"], sample["target_density"]
|
||||||
|
)
|
||||||
|
results.append(
|
||||||
|
annotate(result, f"{sample_case_label(case)} {keep_label(kf)}", y=50)
|
||||||
)
|
)
|
||||||
results.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_target_density[0].cpu()))
|
|
||||||
|
|
||||||
row1 = [to_img(seed_np)] + results[:2]
|
row1 = [annotate(to_img(seed_np), 'seed')] + results[:2]
|
||||||
row2 = results[2:]
|
row2 = results[2:]
|
||||||
rows = [row1, row2]
|
rows = [row1, row2]
|
||||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||||
@ -702,92 +637,20 @@ def visualize_part4(models, device, tile_dict, density_stats: dict):
|
|||||||
def visualize_validate(
|
def visualize_validate(
|
||||||
batch, logits1, logits2, logits3, z_q,
|
batch, logits1, logits2, logits3, z_q,
|
||||||
models: list[torch.nn.Module], device: torch.device, tile_dict,
|
models: list[torch.nn.Module], device: torch.device, tile_dict,
|
||||||
density_stats: dict, epoch: int, batch_idx: int
|
train_dataset: GinkaSeperatedDataset, epoch: int, batch_idx: int
|
||||||
):
|
):
|
||||||
save_dir = f"result/seperated/e{epoch}"
|
save_dir = f"result/seperated/e{epoch}"
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
cv2.imwrite(f"{save_dir}/val{batch_idx}.png", visualize_part1(batch, logits1, logits2, logits3, tile_dict))
|
cv2.imwrite(f"{save_dir}/val{batch_idx}.png", visualize_part1(batch, logits1, logits2, logits3, tile_dict))
|
||||||
cv2.imwrite(f"{save_dir}/full{batch_idx}.png", visualize_part2(batch, z_q, models, device, tile_dict))
|
cv2.imwrite(f"{save_dir}/full{batch_idx}.png", visualize_part2(batch, z_q, models, device, tile_dict))
|
||||||
cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part3(batch, models, device, tile_dict, density_stats))
|
cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part4(train_dataset, models, device, tile_dict))
|
||||||
cv2.imwrite(f"{save_dir}/dvar{batch_idx}.png", visualize_density_var(batch, z_q, models, device, tile_dict))
|
|
||||||
|
|
||||||
# 密度对照图:随机种子+随机结构,5 张随机密度生成,2×3 网格(左上角为种子图)
|
|
||||||
def visualize_density_cmp(models, device, tile_dict, density_stats: dict):
|
|
||||||
SEP = 3
|
|
||||||
TILE_SIZE = 32
|
|
||||||
img_h = MAP_H * TILE_SIZE
|
|
||||||
img_w = MAP_W * TILE_SIZE
|
|
||||||
|
|
||||||
def to_img(mat):
|
|
||||||
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
|
|
||||||
|
|
||||||
n_walls = random.randint(math.floor(MAP_SIZE * 0.02), math.floor(MAP_SIZE * 0.06))
|
|
||||||
seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
|
|
||||||
wall_pos = torch.randperm(MAP_SIZE, device=device)[:n_walls]
|
|
||||||
seed[0, wall_pos] = 1
|
|
||||||
seed_np = seed[0].cpu().numpy().reshape(MAP_H, MAP_W)
|
|
||||||
rnd_struct = random_struct(device)
|
|
||||||
struct_cpu = rnd_struct[0].cpu()
|
|
||||||
gen_imgs = []
|
|
||||||
for _ in range(5):
|
|
||||||
rnd_target_density = random_target_density(density_stats, device)
|
|
||||||
target_density_cpu = rnd_target_density[0].cpu()
|
|
||||||
_, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_target_density, models, device)
|
|
||||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
|
||||||
row1 = [to_img(seed_np)] + gen_imgs[:2]
|
|
||||||
row2 = gen_imgs[2:]
|
|
||||||
rows = [row1, row2]
|
|
||||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
|
||||||
for r, row in enumerate(rows):
|
|
||||||
for c, img in enumerate(row):
|
|
||||||
y = SEP + r * (img_h + SEP)
|
|
||||||
x = SEP + c * (img_w + SEP)
|
|
||||||
grid[y:y + img_h, x:x + img_w] = img
|
|
||||||
return grid
|
|
||||||
|
|
||||||
# 固定 z 和结构条件,扫描 5 个不同墙壁目标密度,2×3 网格(左上角为参考地图)
|
|
||||||
def visualize_density_var(batch, z_q, models, device, tile_dict):
|
|
||||||
SEP = 3
|
|
||||||
TILE_SIZE = 32
|
|
||||||
img_h = MAP_H * TILE_SIZE
|
|
||||||
img_w = MAP_W * TILE_SIZE
|
|
||||||
|
|
||||||
def to_img(mat):
|
|
||||||
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
|
|
||||||
|
|
||||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
|
||||||
struct_t = batch["struct_inject"][0:1].to(device)
|
|
||||||
struct_cpu = batch["struct_inject"][0]
|
|
||||||
base_target_density = batch["target_density"][0:1].to(device)
|
|
||||||
z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1])
|
|
||||||
ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
|
||||||
gen_imgs = []
|
|
||||||
wall_count_values = [20, 35, 50, 65, 80]
|
|
||||||
for wall_count in wall_count_values:
|
|
||||||
fixed_target_density = base_target_density.clone()
|
|
||||||
fixed_target_density[0, WALL_DENSITY_IDX] = wall_count / MAP_SIZE
|
|
||||||
target_density_cpu = fixed_target_density[0].cpu()
|
|
||||||
_, _, merged123 = full_generate_specific_z(
|
|
||||||
inp1_t, z_q_single, struct_t, fixed_target_density, models, device
|
|
||||||
)
|
|
||||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
|
||||||
row1 = [to_img(ref_np)] + gen_imgs[:2]
|
|
||||||
row2 = gen_imgs[2:]
|
|
||||||
rows = [row1, row2]
|
|
||||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
|
||||||
for r, row in enumerate(rows):
|
|
||||||
for c, img in enumerate(row):
|
|
||||||
y = SEP + r * (img_h + SEP)
|
|
||||||
x = SEP + c * (img_w + SEP)
|
|
||||||
grid[y:y + img_h, x:x + img_w] = img
|
|
||||||
return grid
|
|
||||||
|
|
||||||
def validate(
|
def validate(
|
||||||
dataloader: DataLoader,
|
dataloader: DataLoader,
|
||||||
models: list[torch.nn.Module],
|
models: list[torch.nn.Module],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
tile_dict,
|
tile_dict,
|
||||||
density_stats: dict,
|
train_dataset: GinkaSeperatedDataset,
|
||||||
epoch: int
|
epoch: int
|
||||||
):
|
):
|
||||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||||
@ -885,7 +748,7 @@ def validate(
|
|||||||
# 每个 batch 生成三种可视化图(val/full/rand)
|
# 每个 batch 生成三种可视化图(val/full/rand)
|
||||||
visualize_validate(
|
visualize_validate(
|
||||||
batch, logits1, logits2, logits3, z_q,
|
batch, logits1, logits2, logits3, z_q,
|
||||||
models, device, tile_dict, density_stats, epoch, idx
|
models, device, tile_dict, train_dataset, epoch, idx
|
||||||
)
|
)
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
@ -896,12 +759,6 @@ def validate(
|
|||||||
avg_over = density_metrics[tile_id]["over"] / count if count > 0 else 0.0
|
avg_over = density_metrics[tile_id]["over"] / count if count > 0 else 0.0
|
||||||
tqdm.write(f" density {tile_names[tile_id]}: mae={avg_mae:.4f} over={avg_over:.4f}")
|
tqdm.write(f" density {tile_names[tile_id]}: mae={avg_mae:.4f} over={avg_over:.4f}")
|
||||||
|
|
||||||
save_dir = f"result/seperated/e{epoch}"
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
# 每个 epoch 额外生成:无条件自由生成图 + 全局密度对照图
|
|
||||||
cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict, density_stats))
|
|
||||||
cv2.imwrite(f"{save_dir}/density_cmp.png", visualize_density_cmp(models, device, tile_dict, density_stats))
|
|
||||||
|
|
||||||
# 恢复训练模式
|
# 恢复训练模式
|
||||||
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
||||||
m.train()
|
m.train()
|
||||||
@ -1080,7 +937,9 @@ def train(device: torch.device):
|
|||||||
|
|
||||||
# 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存
|
# 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存
|
||||||
if (epoch + 1) % CHECKPOINT == 0:
|
if (epoch + 1) % CHECKPOINT == 0:
|
||||||
losses = validate(dataloader_val, models, device, tile_dict, dataset.density_stats, epoch + 1)
|
losses = validate(
|
||||||
|
dataloader_val, models, device, tile_dict, dataset, epoch + 1
|
||||||
|
)
|
||||||
loss1_total, loss2_total, loss3_total, commit_total, code_hits_total = losses
|
loss1_total, loss2_total, loss3_total, commit_total, code_hits_total = losses
|
||||||
loss1_weighted = STAGE1_CE_WEIGHT * loss1_total
|
loss1_weighted = STAGE1_CE_WEIGHT * loss1_total
|
||||||
loss2_weighted = STAGE2_CE_WEIGHT * loss2_total
|
loss2_weighted = STAGE2_CE_WEIGHT * loss2_total
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user