From 164bb24823c0fb8b68a340fc95771061dd98bd46 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 23 Apr 2026 18:37:10 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=81=94=E5=90=88=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E9=A2=84=E8=A7=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot --- ginka/train_joint.py | 124 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 116 insertions(+), 8 deletions(-) diff --git a/ginka/train_joint.py b/ginka/train_joint.py index 68cbc57..e709fd7 100644 --- a/ginka/train_joint.py +++ b/ginka/train_joint.py @@ -4,10 +4,12 @@ import os import sys from datetime import datetime +import cv2 import numpy as np import torch import torch.nn.functional as F import torch.optim as optim +from perlin_numpy import generate_fractal_noise_2d from torch.utils.data import DataLoader from tqdm import tqdm @@ -15,26 +17,39 @@ from .dataset import GinkaJointDataset from .heatmap.diffusion import Diffusion from .heatmap.model import GinkaHeatmapModel from .maskGIT.model import GinkaMaskGIT +from .utils import nms_sampling +from shared.image import matrix_to_image_cv -BATCH_SIZE = 64 -VAL_BATCH_DIVIDER = 64 +# 地图与 token 基础配置 NUM_CLASSES = 16 MASK_TOKEN = 15 -GENERATE_STEP = 8 MAP_W = 13 MAP_H = 13 HEATMAP_CHANNEL = 9 +GENERATE_STEP = 8 + +# 训练批次与损失配置 +BATCH_SIZE = 64 +VAL_BATCH_DIVIDER = 64 LABEL_SMOOTHING = 0 +CE_WEIGHT = 0.5 # 联合训练里 MaskGIT 监督项的权重 +DROP_RATE = 0.2 # CFG 训练时随机丢弃条件热力图的概率 + +# MaskGIT 模型结构 NUM_LAYERS = 4 D_MODEL = 192 + +# Diffusion 模型结构与噪声过程 NUM_LAYERS_DIFFUSION = 4 D_MODEL_DIFFUSION = 128 T_DIFFUSION = 100 MIN_MASK = 0 MAX_MASK = 1 -CE_WEIGHT = 0.5 -DROP_RATE = 0.2 + +# 验证预览配置 +PREVIEW_CFG_WEIGHT = 5 # 预览生成时使用的 CFG 强度 +RANDOM_PREVIEW_COUNT = 5 # 每次验证额外生成的随机预览数量 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -43,11 +58,13 @@ device = torch.device( ) os.makedirs("result", exist_ok=True) os.makedirs("result/joint", exist_ok=True) 
+os.makedirs("result/joint_img", exist_ok=True) disable_tqdm = not sys.stdout.isatty() def parse_arguments(): + # 解析联合训练脚本的命令行参数。 parser = argparse.ArgumentParser(description="joint training codes") parser.add_argument("--resume", type=bool, default=False) parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth") @@ -62,6 +79,7 @@ def parse_arguments(): def load_heatmap_checkpoint(model, optimizer, args): + # 加载预训练 Diffusion 权重,并在需要时恢复优化器状态。 if not args.state_heatmap: return @@ -78,12 +96,14 @@ def load_heatmap_checkpoint(model, optimizer, args): def freeze_module(module: torch.nn.Module): + # 冻结模块参数,使其在联合训练中只作为固定监督器使用。 module.eval() for parameter in module.parameters(): parameter.requires_grad = False def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor, t: torch.Tensor): + # 根据当前时刻的噪声预测还原 x0 热力图估计。 sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None] sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None] x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab @@ -91,6 +111,7 @@ def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor): + # 用冻结的 MaskGIT 对 Diffusion 生成的热力图施加地图级监督。 batch_size, height, width = target_map.shape target_tokens = target_map.view(batch_size, height * width) canvas = torch.full_like(target_tokens, MASK_TOKEN) @@ -101,6 +122,7 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor if current_mask.sum().item() == 0: break + # 保证前向传播可导 logits = maskgit(canvas, generated_heatmap) ce = F.cross_entropy( logits.permute(0, 2, 1), @@ -131,18 +153,91 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor return torch.stack(losses).mean() -def validate(model, maskgit, diffusion, dataloader, ce_weight): +def load_tile_dict(): + # 加载用于可视化地图的图块贴图。 + tile_dict = dict() + for file in os.listdir('tiles'): + name = 
os.path.splitext(file)[0] + tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) + return tile_dict + + +def get_nms_sampling_count(): + # Sample how many points each channel gets for the random point-map preview. + return [ + np.random.randint(20, 40), + np.random.randint(10, 20), + np.random.randint(10, 30), + np.random.randint(4, 12), + np.random.randint(4, 12), + np.random.randint(2, 6), + np.random.randint(0, 2), + np.random.randint(1, 3), + np.random.randint(2, 10) + ] + + +def maskgit_generate(maskgit, batch_size: int, heatmap: torch.Tensor): + # Decode the heatmap into a complete map using the frozen MaskGIT. + generated_map = torch.full((batch_size, MAP_H * MAP_W), MASK_TOKEN, device=device) + for step in range(GENERATE_STEP): + logits = maskgit(generated_map, heatmap) + probs = F.softmax(logits, dim=-1) + + dist = torch.distributions.Categorical(probs) + sampled_tiles = dist.sample() + confidences = torch.gather(probs, -1, sampled_tiles.unsqueeze(-1)).squeeze(-1) + + ratio = math.cos(((step + 1) / GENERATE_STEP) * math.pi / 2) + num_to_mask = math.floor(ratio * MAP_H * MAP_W) + + if num_to_mask > 0: + _, mask_indices = torch.topk(confidences, k=num_to_mask, largest=False) + sampled_tiles = sampled_tiles.scatter(1, mask_indices, MASK_TOKEN) + + generated_map = sampled_tiles + if (generated_map == MASK_TOKEN).sum() == 0: + break + + return generated_map + + +def full_generate(heatmap_model, maskgit, cond_heatmap: torch.Tensor, diffusion: Diffusion): + # Run the full preview pipeline: point map -> heatmap -> map. + fake_heatmap_cond = diffusion.sample(heatmap_model, cond_heatmap) + fake_heatmap_uncond = diffusion.sample(heatmap_model, torch.zeros_like(cond_heatmap)) + fake_heatmap = fake_heatmap_uncond + PREVIEW_CFG_WEIGHT * (fake_heatmap_cond - fake_heatmap_uncond)  # CFG: extrapolate from the unconditional sample toward the conditional one + return maskgit_generate(maskgit, cond_heatmap.shape[0], fake_heatmap) + + +def save_random_previews(model, maskgit, diffusion, tile_dict): + # Also generate previews from random point maps to inspect open-ended generation. + for preview_idx in range(RANDOM_PREVIEW_COUNT): + cond_array = np.zeros([1, HEATMAP_CHANNEL, MAP_H, MAP_W]) + sampling_count =
get_nms_sampling_count() + for channel in range(HEATMAP_CHANNEL): + noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H, 0:MAP_W] + cond_array[0, channel] = nms_sampling(noise, sampling_count[channel]) + + generated_map = full_generate(model, maskgit, torch.FloatTensor(cond_array).to(device), diffusion) + generated_img = matrix_to_image_cv(generated_map.view(1, MAP_H, MAP_W)[0].cpu().numpy(), tile_dict) + cv2.imwrite(f"result/joint_img/g-{preview_idx}.png", generated_img) + + +def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict): + # 执行数值验证,并保存生成地图预览图。 model.eval() total_loss = 0.0 total_diffusion_loss = 0.0 total_maskgit_loss = 0.0 with torch.no_grad(): + preview_idx = 0 for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm): cond_heatmap = batch["cond_heatmap"].to(device) target_heatmap = batch["target_heatmap"].to(device) target_map = batch["target_map"].to(device) - batch_size = target_heatmap.shape[0] + batch_size, _, map_height, map_width = target_heatmap.shape t = torch.randint(1, T_DIFFUSION, [batch_size], device=device) noise = torch.randn_like(target_heatmap) @@ -159,6 +254,17 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight): total_diffusion_loss += diffusion_loss.item() total_maskgit_loss += maskgit_loss.item() + # 预览生成结果 + generated_map = full_generate(model, maskgit, cond_heatmap, diffusion) + generated_img = matrix_to_image_cv( + generated_map.view(batch_size, map_height, map_width)[0].cpu().numpy(), + tile_dict, + ) + cv2.imwrite(f"result/joint_img/{preview_idx}.png", generated_img) + preview_idx += 1 + + save_random_previews(model, maskgit, diffusion, tile_dict) + size = max(len(dataloader), 1) return { "loss": total_loss / size, @@ -168,9 +274,11 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight): def train(): + # 联合训练 Diffusion,使其同时受到噪声重建和冻结 MaskGIT 的监督。 print(f"Using {device.type} to train model.") args = parse_arguments() + tile_dict = 
load_tile_dict() maskgit = GinkaMaskGIT( num_classes=NUM_CLASSES, @@ -275,7 +383,7 @@ def train(): checkpoint_path, ) - metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT) + metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT, tile_dict) tqdm.write( f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " f"E: {epoch + 1} | "