feat: 联合训练预览

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-23 18:37:10 +08:00
parent 90cfe54bd2
commit 164bb24823

View File

@ -4,10 +4,12 @@ import os
import sys
from datetime import datetime
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from perlin_numpy import generate_fractal_noise_2d
from torch.utils.data import DataLoader
from tqdm import tqdm
@ -15,26 +17,39 @@ from .dataset import GinkaJointDataset
from .heatmap.diffusion import Diffusion
from .heatmap.model import GinkaHeatmapModel
from .maskGIT.model import GinkaMaskGIT
from .utils import nms_sampling
from shared.image import matrix_to_image_cv
BATCH_SIZE = 64
VAL_BATCH_DIVIDER = 64
# 地图与 token 基础配置
NUM_CLASSES = 16
MASK_TOKEN = 15
GENERATE_STEP = 8
MAP_W = 13
MAP_H = 13
HEATMAP_CHANNEL = 9
GENERATE_STEP = 8
# 训练批次与损失配置
BATCH_SIZE = 64
VAL_BATCH_DIVIDER = 64
LABEL_SMOOTHING = 0
CE_WEIGHT = 0.5 # 联合训练里 MaskGIT 监督项的权重
DROP_RATE = 0.2 # CFG 训练时随机丢弃条件热力图的概率
# MaskGIT 模型结构
NUM_LAYERS = 4
D_MODEL = 192
# Diffusion 模型结构与噪声过程
NUM_LAYERS_DIFFUSION = 4
D_MODEL_DIFFUSION = 128
T_DIFFUSION = 100
MIN_MASK = 0
MAX_MASK = 1
CE_WEIGHT = 0.5
DROP_RATE = 0.2
# 验证预览配置
PREVIEW_CFG_WEIGHT = 5 # 预览生成时使用的 CFG 强度
RANDOM_PREVIEW_COUNT = 5 # 每次验证额外生成的随机预览数量
device = torch.device(
"cuda:1" if torch.cuda.is_available()
@ -43,11 +58,13 @@ device = torch.device(
)
os.makedirs("result", exist_ok=True)
os.makedirs("result/joint", exist_ok=True)
os.makedirs("result/joint_img", exist_ok=True)
disable_tqdm = not sys.stdout.isatty()
def parse_arguments():
# 解析联合训练脚本的命令行参数。
parser = argparse.ArgumentParser(description="joint training codes")
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth")
@ -62,6 +79,7 @@ def parse_arguments():
def load_heatmap_checkpoint(model, optimizer, args):
# 加载预训练 Diffusion 权重,并在需要时恢复优化器状态。
if not args.state_heatmap:
return
@ -78,12 +96,14 @@ def load_heatmap_checkpoint(model, optimizer, args):
def freeze_module(module: torch.nn.Module):
# 冻结模块参数,使其在联合训练中只作为固定监督器使用。
module.eval()
for parameter in module.parameters():
parameter.requires_grad = False
def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor, t: torch.Tensor):
# 根据当前时刻的噪声预测还原 x0 热力图估计。
sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None]
sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None]
x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab
@ -91,6 +111,7 @@ def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor
def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor):
# 用冻结的 MaskGIT 对 Diffusion 生成的热力图施加地图级监督。
batch_size, height, width = target_map.shape
target_tokens = target_map.view(batch_size, height * width)
canvas = torch.full_like(target_tokens, MASK_TOKEN)
@ -101,6 +122,7 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
if current_mask.sum().item() == 0:
break
# 保证前向传播可导
logits = maskgit(canvas, generated_heatmap)
ce = F.cross_entropy(
logits.permute(0, 2, 1),
@ -131,18 +153,91 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
return torch.stack(losses).mean()
def validate(model, maskgit, diffusion, dataloader, ce_weight):
def load_tile_dict():
# 加载用于可视化地图的图块贴图。
tile_dict = dict()
for file in os.listdir('tiles'):
name = os.path.splitext(file)[0]
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
return tile_dict
def get_nms_sampling_count():
# 为随机点图预览采样每个通道的点数量。
return [
np.random.randint(20, 40),
np.random.randint(10, 20),
np.random.randint(10, 30),
np.random.randint(4, 12),
np.random.randint(4, 12),
np.random.randint(2, 6),
np.random.randint(0, 2),
np.random.randint(1, 3),
np.random.randint(2, 10)
]
def maskgit_generate(maskgit, batch_size: int, heatmap: torch.Tensor):
# 使用冻结的 MaskGIT 把热力图解码为完整地图。
generated_map = torch.full((batch_size, MAP_H * MAP_W), MASK_TOKEN, device=device)
for step in range(GENERATE_STEP):
logits = maskgit(generated_map, heatmap)
probs = F.softmax(logits, dim=-1)
dist = torch.distributions.Categorical(probs)
sampled_tiles = dist.sample()
confidences = torch.gather(probs, -1, sampled_tiles.unsqueeze(-1)).squeeze(-1)
ratio = math.cos(((step + 1) / GENERATE_STEP) * math.pi / 2)
num_to_mask = math.floor(ratio * MAP_H * MAP_W)
if num_to_mask > 0:
_, mask_indices = torch.topk(confidences, k=num_to_mask, largest=False)
sampled_tiles = sampled_tiles.scatter(1, mask_indices, MASK_TOKEN)
generated_map = sampled_tiles
if (generated_map == MASK_TOKEN).sum() == 0:
break
return generated_map
def full_generate(heatmap_model, maskgit, cond_heatmap: torch.Tensor, diffusion: Diffusion):
# 执行完整预览生成流程:点图 -> 热力图 -> 地图。
fake_heatmap_cond = diffusion.sample(heatmap_model, cond_heatmap)
fake_heatmap_uncond = diffusion.sample(heatmap_model, torch.zeros_like(cond_heatmap))
fake_heatmap = fake_heatmap_uncond + PREVIEW_CFG_WEIGHT * (fake_heatmap_uncond - fake_heatmap_cond)
return maskgit_generate(maskgit, cond_heatmap.shape[0], fake_heatmap)
def save_random_previews(model, maskgit, diffusion, tile_dict):
# 额外生成随机点图预览,便于观察模型的开放式生成效果。
for preview_idx in range(RANDOM_PREVIEW_COUNT):
cond_array = np.ndarray([1, HEATMAP_CHANNEL, MAP_H, MAP_W])
sampling_count = get_nms_sampling_count()
for channel in range(HEATMAP_CHANNEL):
noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H, 0:MAP_W]
cond_array[0, channel] = nms_sampling(noise, sampling_count[channel])
generated_map = full_generate(model, maskgit, torch.FloatTensor(cond_array).to(device), diffusion)
generated_img = matrix_to_image_cv(generated_map.view(1, MAP_H, MAP_W)[0].cpu().numpy(), tile_dict)
cv2.imwrite(f"result/joint_img/g-{preview_idx}.png", generated_img)
def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict):
# 执行数值验证,并保存生成地图预览图。
model.eval()
total_loss = 0.0
total_diffusion_loss = 0.0
total_maskgit_loss = 0.0
with torch.no_grad():
preview_idx = 0
for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm):
cond_heatmap = batch["cond_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device)
target_map = batch["target_map"].to(device)
batch_size = target_heatmap.shape[0]
batch_size, _, map_height, map_width = target_heatmap.shape
t = torch.randint(1, T_DIFFUSION, [batch_size], device=device)
noise = torch.randn_like(target_heatmap)
@ -159,6 +254,17 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight):
total_diffusion_loss += diffusion_loss.item()
total_maskgit_loss += maskgit_loss.item()
# 预览生成结果
generated_map = full_generate(model, maskgit, cond_heatmap, diffusion)
generated_img = matrix_to_image_cv(
generated_map.view(batch_size, map_height, map_width)[0].cpu().numpy(),
tile_dict,
)
cv2.imwrite(f"result/joint_img/{preview_idx}.png", generated_img)
preview_idx += 1
save_random_previews(model, maskgit, diffusion, tile_dict)
size = max(len(dataloader), 1)
return {
"loss": total_loss / size,
@ -168,9 +274,11 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight):
def train():
# 联合训练 Diffusion使其同时受到噪声重建和冻结 MaskGIT 的监督。
print(f"Using {device.type} to train model.")
args = parse_arguments()
tile_dict = load_tile_dict()
maskgit = GinkaMaskGIT(
num_classes=NUM_CLASSES,
@ -275,7 +383,7 @@ def train():
checkpoint_path,
)
metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT)
metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT, tile_dict)
tqdm.write(
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"E: {epoch + 1} | "