feat: 联合训练预览

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-23 18:37:10 +08:00
parent 90cfe54bd2
commit 164bb24823

View File

@ -4,10 +4,12 @@ import os
import sys import sys
from datetime import datetime from datetime import datetime
import cv2
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from perlin_numpy import generate_fractal_noise_2d
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
@ -15,26 +17,39 @@ from .dataset import GinkaJointDataset
from .heatmap.diffusion import Diffusion from .heatmap.diffusion import Diffusion
from .heatmap.model import GinkaHeatmapModel from .heatmap.model import GinkaHeatmapModel
from .maskGIT.model import GinkaMaskGIT from .maskGIT.model import GinkaMaskGIT
from .utils import nms_sampling
from shared.image import matrix_to_image_cv
BATCH_SIZE = 64 # 地图与 token 基础配置
VAL_BATCH_DIVIDER = 64
NUM_CLASSES = 16 NUM_CLASSES = 16
MASK_TOKEN = 15 MASK_TOKEN = 15
GENERATE_STEP = 8
MAP_W = 13 MAP_W = 13
MAP_H = 13 MAP_H = 13
HEATMAP_CHANNEL = 9 HEATMAP_CHANNEL = 9
GENERATE_STEP = 8
# 训练批次与损失配置
BATCH_SIZE = 64
VAL_BATCH_DIVIDER = 64
LABEL_SMOOTHING = 0 LABEL_SMOOTHING = 0
CE_WEIGHT = 0.5 # 联合训练里 MaskGIT 监督项的权重
DROP_RATE = 0.2 # CFG 训练时随机丢弃条件热力图的概率
# MaskGIT 模型结构
NUM_LAYERS = 4 NUM_LAYERS = 4
D_MODEL = 192 D_MODEL = 192
# Diffusion 模型结构与噪声过程
NUM_LAYERS_DIFFUSION = 4 NUM_LAYERS_DIFFUSION = 4
D_MODEL_DIFFUSION = 128 D_MODEL_DIFFUSION = 128
T_DIFFUSION = 100 T_DIFFUSION = 100
MIN_MASK = 0 MIN_MASK = 0
MAX_MASK = 1 MAX_MASK = 1
CE_WEIGHT = 0.5
DROP_RATE = 0.2 # 验证预览配置
PREVIEW_CFG_WEIGHT = 5 # 预览生成时使用的 CFG 强度
RANDOM_PREVIEW_COUNT = 5 # 每次验证额外生成的随机预览数量
device = torch.device( device = torch.device(
"cuda:1" if torch.cuda.is_available() "cuda:1" if torch.cuda.is_available()
@ -43,11 +58,13 @@ device = torch.device(
) )
os.makedirs("result", exist_ok=True) os.makedirs("result", exist_ok=True)
os.makedirs("result/joint", exist_ok=True) os.makedirs("result/joint", exist_ok=True)
os.makedirs("result/joint_img", exist_ok=True)
disable_tqdm = not sys.stdout.isatty() disable_tqdm = not sys.stdout.isatty()
def parse_arguments(): def parse_arguments():
# 解析联合训练脚本的命令行参数。
parser = argparse.ArgumentParser(description="joint training codes") parser = argparse.ArgumentParser(description="joint training codes")
parser.add_argument("--resume", type=bool, default=False) parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth") parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth")
@ -62,6 +79,7 @@ def parse_arguments():
def load_heatmap_checkpoint(model, optimizer, args): def load_heatmap_checkpoint(model, optimizer, args):
# 加载预训练 Diffusion 权重,并在需要时恢复优化器状态。
if not args.state_heatmap: if not args.state_heatmap:
return return
@ -78,12 +96,14 @@ def load_heatmap_checkpoint(model, optimizer, args):
def freeze_module(module: torch.nn.Module): def freeze_module(module: torch.nn.Module):
# 冻结模块参数,使其在联合训练中只作为固定监督器使用。
module.eval() module.eval()
for parameter in module.parameters(): for parameter in module.parameters():
parameter.requires_grad = False parameter.requires_grad = False
def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor, t: torch.Tensor): def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor, t: torch.Tensor):
# 根据当前时刻的噪声预测还原 x0 热力图估计。
sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None] sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None]
sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None] sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None]
x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab
@ -91,6 +111,7 @@ def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor
def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor): def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor):
# 用冻结的 MaskGIT 对 Diffusion 生成的热力图施加地图级监督。
batch_size, height, width = target_map.shape batch_size, height, width = target_map.shape
target_tokens = target_map.view(batch_size, height * width) target_tokens = target_map.view(batch_size, height * width)
canvas = torch.full_like(target_tokens, MASK_TOKEN) canvas = torch.full_like(target_tokens, MASK_TOKEN)
@ -101,6 +122,7 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
if current_mask.sum().item() == 0: if current_mask.sum().item() == 0:
break break
# 保证前向传播可导
logits = maskgit(canvas, generated_heatmap) logits = maskgit(canvas, generated_heatmap)
ce = F.cross_entropy( ce = F.cross_entropy(
logits.permute(0, 2, 1), logits.permute(0, 2, 1),
@ -131,18 +153,91 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
return torch.stack(losses).mean() return torch.stack(losses).mean()
def load_tile_dict():
    """Load the tile sprites used to render generated maps as images.

    Returns:
        dict mapping tile name (filename without extension) to the image
        loaded by OpenCV with its alpha channel preserved. Files OpenCV
        cannot decode are skipped.
    """
    tile_dict = {}
    for file in os.listdir("tiles"):
        name = os.path.splitext(file)[0]
        tile = cv2.imread(os.path.join("tiles", file), cv2.IMREAD_UNCHANGED)
        if tile is None:
            # cv2.imread returns None for unreadable/non-image files;
            # storing None here would crash the renderer later.
            continue
        tile_dict[name] = tile
    return tile_dict
def get_nms_sampling_count():
    """Draw a random point count for each heatmap channel.

    Returns a list with one entry per heatmap channel; each entry is drawn
    uniformly from that channel's (low, high) range.
    """
    # Per-channel (low, high) bounds, in channel order. Kept in one table so
    # exactly one np.random.randint draw happens per channel, in order.
    channel_bounds = (
        (20, 40),
        (10, 20),
        (10, 30),
        (4, 12),
        (4, 12),
        (2, 6),
        (0, 2),
        (1, 3),
        (2, 10),
    )
    return [np.random.randint(low, high) for low, high in channel_bounds]
def maskgit_generate(maskgit, batch_size: int, heatmap: torch.Tensor):
    """Decode a heatmap into a full tile map with the frozen MaskGIT.

    Starts from an all-MASK canvas and iteratively re-samples every token,
    re-masking the least-confident ones on a cosine schedule until no MASK
    token remains or the step budget is spent.
    """
    canvas = torch.full((batch_size, MAP_H * MAP_W), MASK_TOKEN, device=device)
    for step in range(GENERATE_STEP):
        token_probs = F.softmax(maskgit(canvas, heatmap), dim=-1)
        tokens = torch.distributions.Categorical(token_probs).sample()
        # Confidence of each sampled token under the model distribution.
        token_conf = torch.gather(token_probs, -1, tokens.unsqueeze(-1)).squeeze(-1)
        # Cosine schedule: the fraction of tokens to re-mask decays to 0 at
        # the final step, so the last pass commits every position.
        remask_ratio = math.cos(((step + 1) / GENERATE_STEP) * math.pi / 2)
        remask_count = math.floor(remask_ratio * MAP_H * MAP_W)
        if remask_count > 0:
            _, lowest_conf = torch.topk(token_conf, k=remask_count, largest=False)
            tokens = tokens.scatter(1, lowest_conf, MASK_TOKEN)
        canvas = tokens
        if not (canvas == MASK_TOKEN).any():
            break
    return canvas
def full_generate(heatmap_model, maskgit, cond_heatmap: torch.Tensor, diffusion: Diffusion):
    """Run the full preview pipeline: point map -> heatmap -> tile map.

    Samples the diffusion model twice (conditioned on the point map and
    unconditionally) and combines them with classifier-free guidance before
    decoding the guided heatmap into a map with the frozen MaskGIT.
    """
    cond_sample = diffusion.sample(heatmap_model, cond_heatmap)
    uncond_sample = diffusion.sample(heatmap_model, torch.zeros_like(cond_heatmap))
    # Classifier-free guidance: uncond + w * (cond - uncond). The previous
    # (uncond - cond) direction steered generation AWAY from the condition.
    guided_heatmap = uncond_sample + PREVIEW_CFG_WEIGHT * (cond_sample - uncond_sample)
    return maskgit_generate(maskgit, cond_heatmap.shape[0], guided_heatmap)
def save_random_previews(model, maskgit, diffusion, tile_dict):
    """Generate previews from random point maps to inspect open-ended generation.

    For each preview, a fractal-noise field is sampled per heatmap channel
    and reduced to a point map via NMS sampling, then run through the full
    generation pipeline and written to result/joint_img/ as an image.
    """
    for preview_idx in range(RANDOM_PREVIEW_COUNT):
        # np.zeros instead of calling the low-level np.ndarray constructor:
        # np.ndarray(shape) returns uninitialized memory and is not meant to
        # be invoked directly. Every channel is overwritten below either way.
        cond_array = np.zeros((1, HEATMAP_CHANNEL, MAP_H, MAP_W))
        sampling_count = get_nms_sampling_count()
        for channel in range(HEATMAP_CHANNEL):
            # Crop the 16x16 noise field down to the MAP_H x MAP_W map size.
            noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H, 0:MAP_W]
            cond_array[0, channel] = nms_sampling(noise, sampling_count[channel])
        generated_map = full_generate(model, maskgit, torch.FloatTensor(cond_array).to(device), diffusion)
        generated_img = matrix_to_image_cv(generated_map.view(1, MAP_H, MAP_W)[0].cpu().numpy(), tile_dict)
        cv2.imwrite(f"result/joint_img/g-{preview_idx}.png", generated_img)
def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict):
# 执行数值验证,并保存生成地图预览图。
model.eval() model.eval()
total_loss = 0.0 total_loss = 0.0
total_diffusion_loss = 0.0 total_diffusion_loss = 0.0
total_maskgit_loss = 0.0 total_maskgit_loss = 0.0
with torch.no_grad(): with torch.no_grad():
preview_idx = 0
for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm): for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm):
cond_heatmap = batch["cond_heatmap"].to(device) cond_heatmap = batch["cond_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device) target_heatmap = batch["target_heatmap"].to(device)
target_map = batch["target_map"].to(device) target_map = batch["target_map"].to(device)
batch_size = target_heatmap.shape[0] batch_size, _, map_height, map_width = target_heatmap.shape
t = torch.randint(1, T_DIFFUSION, [batch_size], device=device) t = torch.randint(1, T_DIFFUSION, [batch_size], device=device)
noise = torch.randn_like(target_heatmap) noise = torch.randn_like(target_heatmap)
@ -159,6 +254,17 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight):
total_diffusion_loss += diffusion_loss.item() total_diffusion_loss += diffusion_loss.item()
total_maskgit_loss += maskgit_loss.item() total_maskgit_loss += maskgit_loss.item()
# 预览生成结果
generated_map = full_generate(model, maskgit, cond_heatmap, diffusion)
generated_img = matrix_to_image_cv(
generated_map.view(batch_size, map_height, map_width)[0].cpu().numpy(),
tile_dict,
)
cv2.imwrite(f"result/joint_img/{preview_idx}.png", generated_img)
preview_idx += 1
save_random_previews(model, maskgit, diffusion, tile_dict)
size = max(len(dataloader), 1) size = max(len(dataloader), 1)
return { return {
"loss": total_loss / size, "loss": total_loss / size,
@ -168,9 +274,11 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight):
def train(): def train():
# 联合训练 Diffusion使其同时受到噪声重建和冻结 MaskGIT 的监督。
print(f"Using {device.type} to train model.") print(f"Using {device.type} to train model.")
args = parse_arguments() args = parse_arguments()
tile_dict = load_tile_dict()
maskgit = GinkaMaskGIT( maskgit = GinkaMaskGIT(
num_classes=NUM_CLASSES, num_classes=NUM_CLASSES,
@ -275,7 +383,7 @@ def train():
checkpoint_path, checkpoint_path,
) )
metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT) metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT, tile_dict)
tqdm.write( tqdm.write(
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"E: {epoch + 1} | " f"E: {epoch + 1} | "