feat: 从训练集随机抽样生成,而不是完全随机采样

This commit is contained in:
unanmed 2026-05-20 18:02:28 +08:00
parent 416aa4dd72
commit 14ee52fb2f
2 changed files with 121 additions and 233 deletions

View File

@ -71,6 +71,61 @@ class GinkaSeperatedDataset(Dataset):
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
def build_struct_inject(self, map_np: np.ndarray, outer_wall: int) -> torch.Tensor:
sym_h, sym_v, sym_c = compute_symmetry(map_np)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
return torch.LongTensor([cond_sym, outer_wall])
def build_target_density(self, map_data: list) -> torch.Tensor:
return torch.FloatTensor([
self.count_tile(map_data, self.WALL) / self.MAP_SIZE,
self.count_tile(map_data, self.DOOR) / self.MAP_SIZE,
self.count_tile(map_data, self.MONSTER) / self.MAP_SIZE,
self.count_tile(map_data, self.ENTRANCE) / self.MAP_SIZE,
self.count_tile(map_data, self.RESOURCE) / self.MAP_SIZE
])
def build_encoder_inputs(self, raw: np.ndarray) -> tuple:
target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw.copy())
enc1 = target1.copy()
enc2 = inp2.copy()
enc3 = raw.copy()
return enc1, enc2, enc3
def pack_sample(self, item: dict, map_np: np.ndarray, out: tuple) -> dict:
return {
"input_stage1": torch.LongTensor(out[0]),
"target_stage1": torch.LongTensor(out[1]),
"encoder_stage1": torch.LongTensor(out[2]),
"input_stage2": torch.LongTensor(out[3]),
"target_stage2": torch.LongTensor(out[4]),
"encoder_stage2": torch.LongTensor(out[5]),
"input_stage3": torch.LongTensor(out[6]),
"target_stage3": torch.LongTensor(out[7]),
"encoder_stage3": torch.LongTensor(out[8]),
"struct_inject": self.build_struct_inject(map_np, item['outerWall']),
"target_density": self.build_target_density(item['map'])
}
def random_sample_map(self, idx: int | None = None) -> dict:
if idx is None:
idx = random.randrange(len(self.data))
item = self.data[idx]
map_np = np.array(item['map'], dtype=np.int64)
enc1, enc2, enc3 = self.build_encoder_inputs(map_np)
sample = {
"encoder_stage1": torch.LongTensor(enc1),
"encoder_stage2": torch.LongTensor(enc2),
"encoder_stage3": torch.LongTensor(enc3),
"struct_inject": self.build_struct_inject(map_np, item['outerWall']),
"target_density": self.build_target_density(item['map']),
"raw_map": torch.LongTensor(map_np)
}
sample['sample_idx'] = idx
return sample
def degrade_tile(self, m: np.ndarray, tiles: list) -> np.ndarray: def degrade_tile(self, m: np.ndarray, tiles: list) -> np.ndarray:
# 将指定 tile ID 替换为 floor(0),原地修改 # 将指定 tile ID 替换为 floor(0),原地修改
for t in tiles: for t in tiles:
@ -179,30 +234,4 @@ class GinkaSeperatedDataset(Dataset):
else: else:
out = self.apply_subset3(map_np) out = self.apply_subset3(map_np)
sym_h, sym_v, sym_c = compute_symmetry(map_np) return self.pack_sample(item, map_np, out)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
cond_outer = item['outerWall']
struct_inject = torch.LongTensor([cond_sym, cond_outer])
m = item['map']
target_density = torch.FloatTensor([
self.count_tile(m, self.WALL) / self.MAP_SIZE,
self.count_tile(m, self.DOOR) / self.MAP_SIZE,
self.count_tile(m, self.MONSTER) / self.MAP_SIZE,
self.count_tile(m, self.ENTRANCE) / self.MAP_SIZE,
self.count_tile(m, self.RESOURCE) / self.MAP_SIZE,
])
return {
"input_stage1": torch.LongTensor(out[0]),
"target_stage1": torch.LongTensor(out[1]),
"encoder_stage1": torch.LongTensor(out[2]),
"input_stage2": torch.LongTensor(out[3]),
"target_stage2": torch.LongTensor(out[4]),
"encoder_stage2": torch.LongTensor(out[5]),
"input_stage3": torch.LongTensor(out[6]),
"target_stage3": torch.LongTensor(out[7]),
"encoder_stage3": torch.LongTensor(out[8]),
"struct_inject": struct_inject,
"target_density": target_density
}

View File

@ -38,7 +38,7 @@ from shared.image import matrix_to_image_cv
# 共用 VQ-VAE 超参 # 共用 VQ-VAE 超参
# 三组编码器vq1/vq2/vq3共享相同超参分别对三阶段地图上下文独立编码 # 三组编码器vq1/vq2/vq3共享相同超参分别对三阶段地图上下文独立编码
VQ_L = 16 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3 VQ_L = 16 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3
VQ_K = 16 # codebook 大小(离散码本条目数) VQ_K = 32 # codebook 大小(离散码本条目数)
VQ_D_Z = 64 # 码字维度 VQ_D_Z = 64 # 码字维度
VQ_BETA = 1.0 # commit loss 权重(防止编码器输出漂离 codebook VQ_BETA = 1.0 # commit loss 权重(防止编码器输出漂离 codebook
VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用) VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用)
@ -267,28 +267,6 @@ def sample_reference_inputs(
return sampled_reference return sampled_reference
def random_struct(device: torch.device) -> torch.Tensor:
# 随机采样一组结构参量,用于无条件自由生成
# struct_inject 格式:[cond_sym(0-7), cond_outer(0-1)]
cond_sym = random.randint(0, 7) # 地图对称类型
cond_outer = random.randint(0, 1) # 是否有外围走廈
return torch.LongTensor([cond_sym, cond_outer]).unsqueeze(0).to(device)
def random_target_density(density_stats: dict, device: torch.device) -> torch.Tensor:
# 从训练集真实密度范围中采样 wall / door / monster / entrance / resource 目标密度
wall_density = random.uniform(density_stats["wall_min_density"], density_stats["wall_max_density"])
door_density = random.uniform(density_stats["door_min_density"], density_stats["door_max_density"])
monster_density = random.uniform(density_stats["monster_min_density"], density_stats["monster_max_density"])
entrance_density = random.uniform(density_stats["entrance_min_density"], density_stats["entrance_max_density"])
resource_density = random.uniform(density_stats["resource_min_density"], density_stats["resource_max_density"])
return torch.FloatTensor([
wall_density,
door_density,
monster_density,
entrance_density,
resource_density,
]).unsqueeze(0).to(device)
def compute_remaining( def compute_remaining(
current: torch.Tensor, current: torch.Tensor,
target_density: torch.Tensor, target_density: torch.Tensor,
@ -416,50 +394,6 @@ def maskgit_sample(
return current[0].cpu().numpy().reshape(MAP_H, MAP_W) return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
def full_generate_random_z(
input: torch.Tensor,
struct: torch.Tensor,
target_density: torch.Tensor,
models: list[torch.nn.Module],
device: torch.device,
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
) -> tuple:
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
quantizer1, quantizer2, quantizer3 = quantizers
with torch.no_grad():
z1 = quantizer1.sample(1, VQ_L, device)
z2 = quantizer2.sample(1, VQ_L, device)
z3 = quantizer3.sample(1, VQ_L, device)
# stage1生成墙壁骨架
pred1_np = maskgit_sample(
mg1, input.clone(), z1, struct, target_density, 1,
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
)
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充
# stage2在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
pred2_np = maskgit_sample(
mg2, inp2, z2, struct, target_density, 2,
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
)
merged12 = pred1_np.copy()
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp3[inp3 == 0] = MASK_TOKEN
# stage3填充 resource(3)
pred3_np = maskgit_sample(
mg3, inp3, z3, struct, target_density, 3,
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
)
merged123 = merged12.copy()
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
return pred1_np, merged12, merged123
def full_generate_specific_z( def full_generate_specific_z(
input: torch.Tensor, input: torch.Tensor,
z_q: tuple[torch.Tensor, torch.Tensor, torch.Tensor], z_q: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
@ -474,7 +408,7 @@ def full_generate_specific_z(
with torch.no_grad(): with torch.no_grad():
z1, z2, z3 = z_q z1, z2, z3 = z_q
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z # 三阶段级联生成,但使用给定的 z
pred1_np = maskgit_sample( pred1_np = maskgit_sample(
mg1, input.clone(), z1, struct, target_density, 1, mg1, input.clone(), z1, struct, target_density, 1,
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0] GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
@ -500,11 +434,11 @@ def full_generate_specific_z(
return pred1_np, merged12, merged123 return pred1_np, merged12, merged123
def annotate(img: np.ndarray, text: str) -> np.ndarray: def annotate(img: np.ndarray, text: str, y: int = 14) -> np.ndarray:
# 在图片左上角叠加文字标注(黑色描边 + 白色填充,确保任意背景下可读) # 在图片左上角叠加文字标注(黑色描边 + 白色填充,确保任意背景下可读)
img = img.copy() img = img.copy()
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2) cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2)
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
return img return img
def annotate_labels( def annotate_labels(
@ -531,6 +465,40 @@ def rand_keep() -> tuple[bool, bool, bool]:
def keep_label(kf: tuple[bool, bool, bool]) -> str: def keep_label(kf: tuple[bool, bool, bool]) -> str:
return 'fix' if kf[0] else 'free' return 'fix' if kf[0] else 'free'
def build_dataset_sample_case(
dataset: GinkaSeperatedDataset,
models: list[torch.nn.Module],
device: torch.device,
idx: int | None = None
) -> dict:
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
sample = dataset.random_sample_map(idx=idx)
enc1_t = sample["encoder_stage1"].to(device).reshape(1, MAP_SIZE)
enc2_t = sample["encoder_stage2"].to(device).reshape(1, MAP_SIZE)
enc3_t = sample["encoder_stage3"].to(device).reshape(1, MAP_SIZE)
struct_t = sample["struct_inject"].to(device).reshape(1, -1)
target_density_t = sample["target_density"].to(device).reshape(1, -1)
with torch.no_grad():
z_e1 = vq1(enc1_t)
z_e2 = vq2(enc2_t)
z_e3 = vq3(enc3_t)
z_q, commit_loss, code_hits = quantize_stage_latents(
quantizers, z_e1, z_e2, z_e3
)
return {
"sample": sample,
"struct": struct_t,
"target_density": target_density_t,
"z_q": z_q,
"sample_idx": sample["sample_idx"]
}
def sample_case_label(case: dict) -> str:
return f"train#{case['sample_idx']}"
# 验证可视化 part13×3 网格行1=编码器输入行2=掩码输入行3=三阶段预测(合并) # 验证可视化 part13×3 网格行1=编码器输入行2=掩码输入行3=三阶段预测(合并)
def visualize_part1(batch, logits1, logits2, logits3, tile_dict): def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
SEP = 3 SEP = 3
@ -618,52 +586,13 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
grid[y:y + img_h, x:x + img_w] = img grid[y:y + img_h, x:x + img_w] = img
return grid return grid
# 验证可视化 part32×3 网格行1=参考输入+相同 struct 随机 z 生成行2=随机 struct 生成 # 验证可视化 part42×3 网格;保留稀疏墙壁种子,但 z 与标签来自训练集样本
def visualize_part3(batch, models, device, tile_dict, density_stats: dict): def visualize_part4(
SEP = 3 train_dataset: GinkaSeperatedDataset,
TILE_SIZE = 32 models: list[torch.nn.Module],
img_h = MAP_H * TILE_SIZE device: torch.device,
img_w = MAP_W * TILE_SIZE tile_dict
):
def to_img(mat):
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
struct_ref = batch["struct_inject"][0:1].to(device)
target_density_ref = batch["target_density"][0:1].to(device)
inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
struct_cpu = batch["struct_inject"][0]
target_density_cpu = batch["target_density"][0]
row1 = [to_img(inp1_np)]
for _ in range(2):
kf = rand_keep()
_, _, merged123 = full_generate_random_z(
inp1_t, struct_ref, target_density_ref, models, device, keep_fixed=kf
)
row1.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
row2 = []
for _ in range(3):
kf = rand_keep()
rnd_struct = random_struct(device)
rnd_target_density = random_target_density(density_stats, device)
_, _, merged123 = full_generate_random_z(
inp1_t, rnd_struct, rnd_target_density, models, device, keep_fixed=kf
)
row2.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_target_density[0].cpu()))
rows = [row1, row2]
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
for r, row in enumerate(rows):
for c, img in enumerate(row):
y = SEP + r * (img_h + SEP)
x = SEP + c * (img_w + SEP)
grid[y:y + img_h, x:x + img_w] = img
return grid
# 验证可视化 part42×3 网格;以少量随机墙壁作为种子,纯随机 struct+z 自由生成
def visualize_part4(models, device, tile_dict, density_stats: dict):
SEP = 3 SEP = 3
TILE_SIZE = 32 TILE_SIZE = 32
img_h = MAP_H * TILE_SIZE img_h = MAP_H * TILE_SIZE
@ -680,15 +609,21 @@ def visualize_part4(models, device, tile_dict, density_stats: dict):
results = [] results = []
for _ in range(5): for _ in range(5):
case = build_dataset_sample_case(train_dataset, models, device)
kf = rand_keep() kf = rand_keep()
rnd_struct = random_struct(device) sample = case["sample"]
rnd_target_density = random_target_density(density_stats, device) _, _, merged123 = full_generate_specific_z(
_, _, merged123 = full_generate_random_z( seed, case["z_q"], case["struct"], case["target_density"],
seed, rnd_struct, rnd_target_density, models, device, keep_fixed=kf models, device, keep_fixed=kf
)
result = annotate_labels(
to_img(merged123), sample["struct_inject"], sample["target_density"]
)
results.append(
annotate(result, f"{sample_case_label(case)} {keep_label(kf)}", y=50)
) )
results.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_target_density[0].cpu()))
row1 = [to_img(seed_np)] + results[:2] row1 = [annotate(to_img(seed_np), 'seed')] + results[:2]
row2 = results[2:] row2 = results[2:]
rows = [row1, row2] rows = [row1, row2]
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
@ -702,92 +637,20 @@ def visualize_part4(models, device, tile_dict, density_stats: dict):
def visualize_validate( def visualize_validate(
batch, logits1, logits2, logits3, z_q, batch, logits1, logits2, logits3, z_q,
models: list[torch.nn.Module], device: torch.device, tile_dict, models: list[torch.nn.Module], device: torch.device, tile_dict,
density_stats: dict, epoch: int, batch_idx: int train_dataset: GinkaSeperatedDataset, epoch: int, batch_idx: int
): ):
save_dir = f"result/seperated/e{epoch}" save_dir = f"result/seperated/e{epoch}"
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
cv2.imwrite(f"{save_dir}/val{batch_idx}.png", visualize_part1(batch, logits1, logits2, logits3, tile_dict)) cv2.imwrite(f"{save_dir}/val{batch_idx}.png", visualize_part1(batch, logits1, logits2, logits3, tile_dict))
cv2.imwrite(f"{save_dir}/full{batch_idx}.png", visualize_part2(batch, z_q, models, device, tile_dict)) cv2.imwrite(f"{save_dir}/full{batch_idx}.png", visualize_part2(batch, z_q, models, device, tile_dict))
cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part3(batch, models, device, tile_dict, density_stats)) cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part4(train_dataset, models, device, tile_dict))
cv2.imwrite(f"{save_dir}/dvar{batch_idx}.png", visualize_density_var(batch, z_q, models, device, tile_dict))
# 密度对照图:随机种子+随机结构5 张随机密度生成2×3 网格(左上角为种子图)
def visualize_density_cmp(models, device, tile_dict, density_stats: dict):
SEP = 3
TILE_SIZE = 32
img_h = MAP_H * TILE_SIZE
img_w = MAP_W * TILE_SIZE
def to_img(mat):
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
n_walls = random.randint(math.floor(MAP_SIZE * 0.02), math.floor(MAP_SIZE * 0.06))
seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
wall_pos = torch.randperm(MAP_SIZE, device=device)[:n_walls]
seed[0, wall_pos] = 1
seed_np = seed[0].cpu().numpy().reshape(MAP_H, MAP_W)
rnd_struct = random_struct(device)
struct_cpu = rnd_struct[0].cpu()
gen_imgs = []
for _ in range(5):
rnd_target_density = random_target_density(density_stats, device)
target_density_cpu = rnd_target_density[0].cpu()
_, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_target_density, models, device)
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
row1 = [to_img(seed_np)] + gen_imgs[:2]
row2 = gen_imgs[2:]
rows = [row1, row2]
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
for r, row in enumerate(rows):
for c, img in enumerate(row):
y = SEP + r * (img_h + SEP)
x = SEP + c * (img_w + SEP)
grid[y:y + img_h, x:x + img_w] = img
return grid
# 固定 z 和结构条件,扫描 5 个不同墙壁目标密度2×3 网格(左上角为参考地图)
def visualize_density_var(batch, z_q, models, device, tile_dict):
SEP = 3
TILE_SIZE = 32
img_h = MAP_H * TILE_SIZE
img_w = MAP_W * TILE_SIZE
def to_img(mat):
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
struct_t = batch["struct_inject"][0:1].to(device)
struct_cpu = batch["struct_inject"][0]
base_target_density = batch["target_density"][0:1].to(device)
z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1])
ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
gen_imgs = []
wall_count_values = [20, 35, 50, 65, 80]
for wall_count in wall_count_values:
fixed_target_density = base_target_density.clone()
fixed_target_density[0, WALL_DENSITY_IDX] = wall_count / MAP_SIZE
target_density_cpu = fixed_target_density[0].cpu()
_, _, merged123 = full_generate_specific_z(
inp1_t, z_q_single, struct_t, fixed_target_density, models, device
)
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
row1 = [to_img(ref_np)] + gen_imgs[:2]
row2 = gen_imgs[2:]
rows = [row1, row2]
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
for r, row in enumerate(rows):
for c, img in enumerate(row):
y = SEP + r * (img_h + SEP)
x = SEP + c * (img_w + SEP)
grid[y:y + img_h, x:x + img_w] = img
return grid
def validate( def validate(
dataloader: DataLoader, dataloader: DataLoader,
models: list[torch.nn.Module], models: list[torch.nn.Module],
device: torch.device, device: torch.device,
tile_dict, tile_dict,
density_stats: dict, train_dataset: GinkaSeperatedDataset,
epoch: int epoch: int
): ):
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
@ -885,7 +748,7 @@ def validate(
# 每个 batch 生成三种可视化图val/full/rand # 每个 batch 生成三种可视化图val/full/rand
visualize_validate( visualize_validate(
batch, logits1, logits2, logits3, z_q, batch, logits1, logits2, logits3, z_q,
models, device, tile_dict, density_stats, epoch, idx models, device, tile_dict, train_dataset, epoch, idx
) )
idx += 1 idx += 1
@ -896,12 +759,6 @@ def validate(
avg_over = density_metrics[tile_id]["over"] / count if count > 0 else 0.0 avg_over = density_metrics[tile_id]["over"] / count if count > 0 else 0.0
tqdm.write(f" density {tile_names[tile_id]}: mae={avg_mae:.4f} over={avg_over:.4f}") tqdm.write(f" density {tile_names[tile_id]}: mae={avg_mae:.4f} over={avg_over:.4f}")
save_dir = f"result/seperated/e{epoch}"
os.makedirs(save_dir, exist_ok=True)
# 每个 epoch 额外生成:无条件自由生成图 + 全局密度对照图
cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict, density_stats))
cv2.imwrite(f"{save_dir}/density_cmp.png", visualize_density_cmp(models, device, tile_dict, density_stats))
# 恢复训练模式 # 恢复训练模式
for m in [vq1, vq2, vq3, mg1, mg2, mg3]: for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
m.train() m.train()
@ -1080,7 +937,9 @@ def train(device: torch.device):
# 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存 # 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存
if (epoch + 1) % CHECKPOINT == 0: if (epoch + 1) % CHECKPOINT == 0:
losses = validate(dataloader_val, models, device, tile_dict, dataset.density_stats, epoch + 1) losses = validate(
dataloader_val, models, device, tile_dict, dataset, epoch + 1
)
loss1_total, loss2_total, loss3_total, commit_total, code_hits_total = losses loss1_total, loss2_total, loss3_total, commit_total, code_hits_total = losses
loss1_weighted = STAGE1_CE_WEIGHT * loss1_total loss1_weighted = STAGE1_CE_WEIGHT * loss1_total
loss2_weighted = STAGE2_CE_WEIGHT * loss2_total loss2_weighted = STAGE2_CE_WEIGHT * loss2_total