mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 联合训练预览
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
90cfe54bd2
commit
164bb24823
@ -4,10 +4,12 @@ import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from perlin_numpy import generate_fractal_noise_2d
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -15,26 +17,39 @@ from .dataset import GinkaJointDataset
|
||||
from .heatmap.diffusion import Diffusion
|
||||
from .heatmap.model import GinkaHeatmapModel
|
||||
from .maskGIT.model import GinkaMaskGIT
|
||||
from .utils import nms_sampling
|
||||
from shared.image import matrix_to_image_cv
|
||||
|
||||
|
||||
BATCH_SIZE = 64
|
||||
VAL_BATCH_DIVIDER = 64
|
||||
# 地图与 token 基础配置
|
||||
NUM_CLASSES = 16
|
||||
MASK_TOKEN = 15
|
||||
GENERATE_STEP = 8
|
||||
MAP_W = 13
|
||||
MAP_H = 13
|
||||
HEATMAP_CHANNEL = 9
|
||||
GENERATE_STEP = 8
|
||||
|
||||
# 训练批次与损失配置
|
||||
BATCH_SIZE = 64
|
||||
VAL_BATCH_DIVIDER = 64
|
||||
LABEL_SMOOTHING = 0
|
||||
CE_WEIGHT = 0.5 # 联合训练里 MaskGIT 监督项的权重
|
||||
DROP_RATE = 0.2 # CFG 训练时随机丢弃条件热力图的概率
|
||||
|
||||
# MaskGIT 模型结构
|
||||
NUM_LAYERS = 4
|
||||
D_MODEL = 192
|
||||
|
||||
# Diffusion 模型结构与噪声过程
|
||||
NUM_LAYERS_DIFFUSION = 4
|
||||
D_MODEL_DIFFUSION = 128
|
||||
T_DIFFUSION = 100
|
||||
MIN_MASK = 0
|
||||
MAX_MASK = 1
|
||||
CE_WEIGHT = 0.5
|
||||
DROP_RATE = 0.2
|
||||
|
||||
# 验证预览配置
|
||||
PREVIEW_CFG_WEIGHT = 5 # 预览生成时使用的 CFG 强度
|
||||
RANDOM_PREVIEW_COUNT = 5 # 每次验证额外生成的随机预览数量
|
||||
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
@ -43,11 +58,13 @@ device = torch.device(
|
||||
)
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/joint", exist_ok=True)
|
||||
os.makedirs("result/joint_img", exist_ok=True)
|
||||
|
||||
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
# 解析联合训练脚本的命令行参数。
|
||||
parser = argparse.ArgumentParser(description="joint training codes")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth")
|
||||
@ -62,6 +79,7 @@ def parse_arguments():
|
||||
|
||||
|
||||
def load_heatmap_checkpoint(model, optimizer, args):
|
||||
# 加载预训练 Diffusion 权重,并在需要时恢复优化器状态。
|
||||
if not args.state_heatmap:
|
||||
return
|
||||
|
||||
@ -78,12 +96,14 @@ def load_heatmap_checkpoint(model, optimizer, args):
|
||||
|
||||
|
||||
def freeze_module(module: torch.nn.Module):
|
||||
# 冻结模块参数,使其在联合训练中只作为固定监督器使用。
|
||||
module.eval()
|
||||
for parameter in module.parameters():
|
||||
parameter.requires_grad = False
|
||||
|
||||
|
||||
def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor, t: torch.Tensor):
|
||||
# 根据当前时刻的噪声预测还原 x0 热力图估计。
|
||||
sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None]
|
||||
sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None]
|
||||
x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab
|
||||
@ -91,6 +111,7 @@ def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor
|
||||
|
||||
|
||||
def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor):
|
||||
# 用冻结的 MaskGIT 对 Diffusion 生成的热力图施加地图级监督。
|
||||
batch_size, height, width = target_map.shape
|
||||
target_tokens = target_map.view(batch_size, height * width)
|
||||
canvas = torch.full_like(target_tokens, MASK_TOKEN)
|
||||
@ -101,6 +122,7 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
|
||||
if current_mask.sum().item() == 0:
|
||||
break
|
||||
|
||||
# 保证前向传播可导
|
||||
logits = maskgit(canvas, generated_heatmap)
|
||||
ce = F.cross_entropy(
|
||||
logits.permute(0, 2, 1),
|
||||
@ -131,18 +153,91 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
|
||||
return torch.stack(losses).mean()
|
||||
|
||||
|
||||
def validate(model, maskgit, diffusion, dataloader, ce_weight):
|
||||
def load_tile_dict():
|
||||
# 加载用于可视化地图的图块贴图。
|
||||
tile_dict = dict()
|
||||
for file in os.listdir('tiles'):
|
||||
name = os.path.splitext(file)[0]
|
||||
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
||||
return tile_dict
|
||||
|
||||
|
||||
def get_nms_sampling_count():
|
||||
# 为随机点图预览采样每个通道的点数量。
|
||||
return [
|
||||
np.random.randint(20, 40),
|
||||
np.random.randint(10, 20),
|
||||
np.random.randint(10, 30),
|
||||
np.random.randint(4, 12),
|
||||
np.random.randint(4, 12),
|
||||
np.random.randint(2, 6),
|
||||
np.random.randint(0, 2),
|
||||
np.random.randint(1, 3),
|
||||
np.random.randint(2, 10)
|
||||
]
|
||||
|
||||
|
||||
def maskgit_generate(maskgit, batch_size: int, heatmap: torch.Tensor):
|
||||
# 使用冻结的 MaskGIT 把热力图解码为完整地图。
|
||||
generated_map = torch.full((batch_size, MAP_H * MAP_W), MASK_TOKEN, device=device)
|
||||
for step in range(GENERATE_STEP):
|
||||
logits = maskgit(generated_map, heatmap)
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
||||
dist = torch.distributions.Categorical(probs)
|
||||
sampled_tiles = dist.sample()
|
||||
confidences = torch.gather(probs, -1, sampled_tiles.unsqueeze(-1)).squeeze(-1)
|
||||
|
||||
ratio = math.cos(((step + 1) / GENERATE_STEP) * math.pi / 2)
|
||||
num_to_mask = math.floor(ratio * MAP_H * MAP_W)
|
||||
|
||||
if num_to_mask > 0:
|
||||
_, mask_indices = torch.topk(confidences, k=num_to_mask, largest=False)
|
||||
sampled_tiles = sampled_tiles.scatter(1, mask_indices, MASK_TOKEN)
|
||||
|
||||
generated_map = sampled_tiles
|
||||
if (generated_map == MASK_TOKEN).sum() == 0:
|
||||
break
|
||||
|
||||
return generated_map
|
||||
|
||||
|
||||
def full_generate(heatmap_model, maskgit, cond_heatmap: torch.Tensor, diffusion: Diffusion):
|
||||
# 执行完整预览生成流程:点图 -> 热力图 -> 地图。
|
||||
fake_heatmap_cond = diffusion.sample(heatmap_model, cond_heatmap)
|
||||
fake_heatmap_uncond = diffusion.sample(heatmap_model, torch.zeros_like(cond_heatmap))
|
||||
fake_heatmap = fake_heatmap_uncond + PREVIEW_CFG_WEIGHT * (fake_heatmap_uncond - fake_heatmap_cond)
|
||||
return maskgit_generate(maskgit, cond_heatmap.shape[0], fake_heatmap)
|
||||
|
||||
|
||||
def save_random_previews(model, maskgit, diffusion, tile_dict):
|
||||
# 额外生成随机点图预览,便于观察模型的开放式生成效果。
|
||||
for preview_idx in range(RANDOM_PREVIEW_COUNT):
|
||||
cond_array = np.ndarray([1, HEATMAP_CHANNEL, MAP_H, MAP_W])
|
||||
sampling_count = get_nms_sampling_count()
|
||||
for channel in range(HEATMAP_CHANNEL):
|
||||
noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H, 0:MAP_W]
|
||||
cond_array[0, channel] = nms_sampling(noise, sampling_count[channel])
|
||||
|
||||
generated_map = full_generate(model, maskgit, torch.FloatTensor(cond_array).to(device), diffusion)
|
||||
generated_img = matrix_to_image_cv(generated_map.view(1, MAP_H, MAP_W)[0].cpu().numpy(), tile_dict)
|
||||
cv2.imwrite(f"result/joint_img/g-{preview_idx}.png", generated_img)
|
||||
|
||||
|
||||
def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict):
|
||||
# 执行数值验证,并保存生成地图预览图。
|
||||
model.eval()
|
||||
total_loss = 0.0
|
||||
total_diffusion_loss = 0.0
|
||||
total_maskgit_loss = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
preview_idx = 0
|
||||
for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm):
|
||||
cond_heatmap = batch["cond_heatmap"].to(device)
|
||||
target_heatmap = batch["target_heatmap"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
batch_size = target_heatmap.shape[0]
|
||||
batch_size, _, map_height, map_width = target_heatmap.shape
|
||||
|
||||
t = torch.randint(1, T_DIFFUSION, [batch_size], device=device)
|
||||
noise = torch.randn_like(target_heatmap)
|
||||
@ -159,6 +254,17 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight):
|
||||
total_diffusion_loss += diffusion_loss.item()
|
||||
total_maskgit_loss += maskgit_loss.item()
|
||||
|
||||
# 预览生成结果
|
||||
generated_map = full_generate(model, maskgit, cond_heatmap, diffusion)
|
||||
generated_img = matrix_to_image_cv(
|
||||
generated_map.view(batch_size, map_height, map_width)[0].cpu().numpy(),
|
||||
tile_dict,
|
||||
)
|
||||
cv2.imwrite(f"result/joint_img/{preview_idx}.png", generated_img)
|
||||
preview_idx += 1
|
||||
|
||||
save_random_previews(model, maskgit, diffusion, tile_dict)
|
||||
|
||||
size = max(len(dataloader), 1)
|
||||
return {
|
||||
"loss": total_loss / size,
|
||||
@ -168,9 +274,11 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight):
|
||||
|
||||
|
||||
def train():
|
||||
# 联合训练 Diffusion,使其同时受到噪声重建和冻结 MaskGIT 的监督。
|
||||
print(f"Using {device.type} to train model.")
|
||||
|
||||
args = parse_arguments()
|
||||
tile_dict = load_tile_dict()
|
||||
|
||||
maskgit = GinkaMaskGIT(
|
||||
num_classes=NUM_CLASSES,
|
||||
@ -275,7 +383,7 @@ def train():
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT)
|
||||
metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT, tile_dict)
|
||||
tqdm.write(
|
||||
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||
f"E: {epoch + 1} | "
|
||||
|
||||
Loading…
Reference in New Issue
Block a user