mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 联合训练预览
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
90cfe54bd2
commit
164bb24823
@ -4,10 +4,12 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
from perlin_numpy import generate_fractal_noise_2d
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@ -15,26 +17,39 @@ from .dataset import GinkaJointDataset
|
|||||||
from .heatmap.diffusion import Diffusion
|
from .heatmap.diffusion import Diffusion
|
||||||
from .heatmap.model import GinkaHeatmapModel
|
from .heatmap.model import GinkaHeatmapModel
|
||||||
from .maskGIT.model import GinkaMaskGIT
|
from .maskGIT.model import GinkaMaskGIT
|
||||||
|
from .utils import nms_sampling
|
||||||
|
from shared.image import matrix_to_image_cv
|
||||||
|
|
||||||
|
|
||||||
# --- Map and token basics ---
NUM_CLASSES = 16
MASK_TOKEN = 15
MAP_W = 13
MAP_H = 13
HEATMAP_CHANNEL = 9
GENERATE_STEP = 8

# --- Training batches and loss ---
BATCH_SIZE = 64
VAL_BATCH_DIVIDER = 64
LABEL_SMOOTHING = 0
# Weight of the MaskGIT supervision term in joint training.
CE_WEIGHT = 0.5
# Probability of randomly dropping the conditional heatmap during CFG training.
DROP_RATE = 0.2

# --- MaskGIT model structure ---
NUM_LAYERS = 4
D_MODEL = 192

# --- Diffusion model structure and noise process ---
NUM_LAYERS_DIFFUSION = 4
D_MODEL_DIFFUSION = 128
T_DIFFUSION = 100
MIN_MASK = 0
MAX_MASK = 1

# --- Validation preview ---
# CFG strength used when generating previews.
PREVIEW_CFG_WEIGHT = 5
# Number of extra random previews generated on every validation run.
RANDOM_PREVIEW_COUNT = 5
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
"cuda:1" if torch.cuda.is_available()
|
"cuda:1" if torch.cuda.is_available()
|
||||||
@ -43,11 +58,13 @@ device = torch.device(
|
|||||||
)
|
)
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
os.makedirs("result/joint", exist_ok=True)
|
os.makedirs("result/joint", exist_ok=True)
|
||||||
|
os.makedirs("result/joint_img", exist_ok=True)
|
||||||
|
|
||||||
disable_tqdm = not sys.stdout.isatty()
|
disable_tqdm = not sys.stdout.isatty()
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
|
# 解析联合训练脚本的命令行参数。
|
||||||
parser = argparse.ArgumentParser(description="joint training codes")
|
parser = argparse.ArgumentParser(description="joint training codes")
|
||||||
parser.add_argument("--resume", type=bool, default=False)
|
parser.add_argument("--resume", type=bool, default=False)
|
||||||
parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth")
|
parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth")
|
||||||
@ -62,6 +79,7 @@ def parse_arguments():
|
|||||||
|
|
||||||
|
|
||||||
def load_heatmap_checkpoint(model, optimizer, args):
|
def load_heatmap_checkpoint(model, optimizer, args):
|
||||||
|
# 加载预训练 Diffusion 权重,并在需要时恢复优化器状态。
|
||||||
if not args.state_heatmap:
|
if not args.state_heatmap:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -78,12 +96,14 @@ def load_heatmap_checkpoint(model, optimizer, args):
|
|||||||
|
|
||||||
|
|
||||||
def freeze_module(module: torch.nn.Module):
    """Switch ``module`` to eval mode and disable gradients on its parameters.

    The module is modified in place; nothing is returned. Gradients can still
    flow *through* the module to its inputs, only its own parameters stop
    accumulating them.

    Args:
        module: The network whose parameters should no longer be trained.
    """
    module.eval()
    # nn.Module.requires_grad_ flips the flag on every parameter in one call.
    module.requires_grad_(False)
|
||||||
|
|
||||||
|
|
||||||
def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor, t: torch.Tensor):
|
def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor, t: torch.Tensor):
|
||||||
|
# 根据当前时刻的噪声预测还原 x0 热力图估计。
|
||||||
sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None]
|
sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None]
|
||||||
sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None]
|
sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None]
|
||||||
x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab
|
x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab
|
||||||
@ -91,6 +111,7 @@ def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor
|
|||||||
|
|
||||||
|
|
||||||
def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor):
|
def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor):
|
||||||
|
# 用冻结的 MaskGIT 对 Diffusion 生成的热力图施加地图级监督。
|
||||||
batch_size, height, width = target_map.shape
|
batch_size, height, width = target_map.shape
|
||||||
target_tokens = target_map.view(batch_size, height * width)
|
target_tokens = target_map.view(batch_size, height * width)
|
||||||
canvas = torch.full_like(target_tokens, MASK_TOKEN)
|
canvas = torch.full_like(target_tokens, MASK_TOKEN)
|
||||||
@ -101,6 +122,7 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
|
|||||||
if current_mask.sum().item() == 0:
|
if current_mask.sum().item() == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 保证前向传播可导
|
||||||
logits = maskgit(canvas, generated_heatmap)
|
logits = maskgit(canvas, generated_heatmap)
|
||||||
ce = F.cross_entropy(
|
ce = F.cross_entropy(
|
||||||
logits.permute(0, 2, 1),
|
logits.permute(0, 2, 1),
|
||||||
@ -131,18 +153,91 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
|
|||||||
return torch.stack(losses).mean()
|
return torch.stack(losses).mean()
|
||||||
|
|
||||||
|
|
||||||
def validate(model, maskgit, diffusion, dataloader, ce_weight):
|
def load_tile_dict():
    """Load the tile images used to render generated maps.

    Every entry of the ``tiles`` directory (relative to the current working
    directory) is read with ``cv2.imread(..., cv2.IMREAD_UNCHANGED)``.

    Returns:
        dict: Maps each file name without its extension to whatever
        ``cv2.imread`` returned for that file.
    """
    return {
        os.path.splitext(file)[0]: cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
        for file in os.listdir('tiles')
    }
|
||||||
|
|
||||||
|
|
||||||
|
def get_nms_sampling_count():
    """Draw a random point count for each channel of a random preview.

    Returns:
        list: Nine integers, one per ``(low, high)`` pair below, each drawn
        with ``np.random.randint(low, high)`` (``high`` exclusive) in order.
    """
    # One (low, high) pair per channel; the order fixes the RNG call sequence.
    bounds = (
        (20, 40),
        (10, 20),
        (10, 30),
        (4, 12),
        (4, 12),
        (2, 6),
        (0, 2),
        (1, 3),
        (2, 10),
    )
    return [np.random.randint(low, high) for low, high in bounds]
|
||||||
|
|
||||||
|
|
||||||
|
def maskgit_generate(maskgit, batch_size: int, heatmap: torch.Tensor):
    """Decode a heatmap into a full token map by iterative sample-and-remask.

    Starts from a canvas filled with ``MASK_TOKEN`` and runs at most
    ``GENERATE_STEP`` passes. Each pass samples a token for every position,
    then re-masks the least confident positions according to a cosine
    schedule.

    Args:
        maskgit: Callable invoked as ``maskgit(tokens, heatmap)``; assumed to
            return logits shaped (batch, MAP_H * MAP_W, classes) — TODO confirm.
        batch_size: Number of maps to generate.
        heatmap: Conditioning heatmap passed unchanged to ``maskgit``.

    Returns:
        torch.Tensor: Token ids of shape (batch_size, MAP_H * MAP_W), created
        on the module-level ``device``.
    """
    tokens = torch.full((batch_size, MAP_H * MAP_W), MASK_TOKEN, device=device)

    for step in range(GENERATE_STEP):
        probs = F.softmax(maskgit(tokens, heatmap), dim=-1)

        # Every position is re-sampled on each pass, including positions that
        # were kept unmasked by the previous pass.
        candidates = torch.distributions.Categorical(probs).sample()
        confidences = torch.gather(probs, -1, candidates.unsqueeze(-1)).squeeze(-1)

        # Cosine schedule: the masked fraction shrinks towards zero; on the
        # final pass the ratio is ~0, so nothing is re-masked.
        ratio = math.cos(((step + 1) / GENERATE_STEP) * math.pi / 2)
        num_to_mask = math.floor(ratio * MAP_H * MAP_W)

        if num_to_mask > 0:
            # Send the lowest-confidence samples back to MASK_TOKEN.
            weakest = torch.topk(confidences, k=num_to_mask, largest=False).indices
            candidates = candidates.scatter(1, weakest, MASK_TOKEN)

        tokens = candidates
        if not (tokens == MASK_TOKEN).any():
            break

    return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def full_generate(heatmap_model, maskgit, cond_heatmap: torch.Tensor, diffusion: Diffusion):
    """Run the full preview pipeline: point map -> heatmap -> tile map.

    Two heatmaps are sampled from the diffusion model, one conditioned on
    ``cond_heatmap`` and one on an all-zero condition, and combined with
    classifier-free guidance before being decoded by MaskGIT.

    Args:
        heatmap_model: Model handed to ``diffusion.sample``.
        maskgit: Decoder handed to ``maskgit_generate``.
        cond_heatmap: Conditioning point map; its first dimension is used as
            the batch size.
        diffusion: Object providing ``sample(model, cond)``.

    Returns:
        torch.Tensor: The token map returned by ``maskgit_generate``.
    """
    fake_heatmap_cond = diffusion.sample(heatmap_model, cond_heatmap)
    # An all-zero condition stands in for the unconditional branch
    # (presumably matching how the condition is dropped in training — verify).
    fake_heatmap_uncond = diffusion.sample(heatmap_model, torch.zeros_like(cond_heatmap))

    # Classifier-free guidance extrapolates from the unconditional result
    # TOWARDS the conditional one: uncond + w * (cond - uncond). The previous
    # code used (uncond - cond), which pushed the preview away from the
    # condition instead of strengthening it.
    # NOTE(review): guidance is applied to two independently sampled outputs
    # rather than per denoising step; confirm this is intended.
    fake_heatmap = fake_heatmap_uncond + PREVIEW_CFG_WEIGHT * (fake_heatmap_cond - fake_heatmap_uncond)
    return maskgit_generate(maskgit, cond_heatmap.shape[0], fake_heatmap)
|
||||||
|
|
||||||
|
|
||||||
|
def save_random_previews(model, maskgit, diffusion, tile_dict):
    """Generate and save previews from randomly synthesised point maps.

    For each of ``RANDOM_PREVIEW_COUNT`` previews, every condition channel is
    built from fractal noise passed through ``nms_sampling``; the result is
    run through ``full_generate`` and written to
    ``result/joint_img/g-<index>.png``.

    Args:
        model: Heatmap model forwarded to ``full_generate``.
        maskgit: MaskGIT decoder forwarded to ``full_generate``.
        diffusion: Diffusion helper forwarded to ``full_generate``.
        tile_dict: Tile images handed to ``matrix_to_image_cv``.
    """
    for preview_idx in range(RANDOM_PREVIEW_COUNT):
        # Uninitialised buffer is fine: every channel is overwritten below.
        cond_array = np.empty((1, HEATMAP_CHANNEL, MAP_H, MAP_W))
        sampling_count = get_nms_sampling_count()

        for channel, point_count in zip(range(HEATMAP_CHANNEL), sampling_count):
            # Noise is generated at 16x16 and cropped to the map size.
            noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[:MAP_H, :MAP_W]
            cond_array[0, channel] = nms_sampling(noise, point_count)

        cond_tensor = torch.FloatTensor(cond_array).to(device)
        generated_map = full_generate(model, maskgit, cond_tensor, diffusion)

        tile_matrix = generated_map.view(1, MAP_H, MAP_W)[0].cpu().numpy()
        generated_img = matrix_to_image_cv(tile_matrix, tile_dict)
        cv2.imwrite(f"result/joint_img/g-{preview_idx}.png", generated_img)
|
||||||
|
|
||||||
|
|
||||||
|
def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict):
|
||||||
|
# 执行数值验证,并保存生成地图预览图。
|
||||||
model.eval()
|
model.eval()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
total_diffusion_loss = 0.0
|
total_diffusion_loss = 0.0
|
||||||
total_maskgit_loss = 0.0
|
total_maskgit_loss = 0.0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
preview_idx = 0
|
||||||
for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm):
|
for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm):
|
||||||
cond_heatmap = batch["cond_heatmap"].to(device)
|
cond_heatmap = batch["cond_heatmap"].to(device)
|
||||||
target_heatmap = batch["target_heatmap"].to(device)
|
target_heatmap = batch["target_heatmap"].to(device)
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
batch_size = target_heatmap.shape[0]
|
batch_size, _, map_height, map_width = target_heatmap.shape
|
||||||
|
|
||||||
t = torch.randint(1, T_DIFFUSION, [batch_size], device=device)
|
t = torch.randint(1, T_DIFFUSION, [batch_size], device=device)
|
||||||
noise = torch.randn_like(target_heatmap)
|
noise = torch.randn_like(target_heatmap)
|
||||||
@ -159,6 +254,17 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight):
|
|||||||
total_diffusion_loss += diffusion_loss.item()
|
total_diffusion_loss += diffusion_loss.item()
|
||||||
total_maskgit_loss += maskgit_loss.item()
|
total_maskgit_loss += maskgit_loss.item()
|
||||||
|
|
||||||
|
# 预览生成结果
|
||||||
|
generated_map = full_generate(model, maskgit, cond_heatmap, diffusion)
|
||||||
|
generated_img = matrix_to_image_cv(
|
||||||
|
generated_map.view(batch_size, map_height, map_width)[0].cpu().numpy(),
|
||||||
|
tile_dict,
|
||||||
|
)
|
||||||
|
cv2.imwrite(f"result/joint_img/{preview_idx}.png", generated_img)
|
||||||
|
preview_idx += 1
|
||||||
|
|
||||||
|
save_random_previews(model, maskgit, diffusion, tile_dict)
|
||||||
|
|
||||||
size = max(len(dataloader), 1)
|
size = max(len(dataloader), 1)
|
||||||
return {
|
return {
|
||||||
"loss": total_loss / size,
|
"loss": total_loss / size,
|
||||||
@ -168,9 +274,11 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight):
|
|||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
|
# 联合训练 Diffusion,使其同时受到噪声重建和冻结 MaskGIT 的监督。
|
||||||
print(f"Using {device.type} to train model.")
|
print(f"Using {device.type} to train model.")
|
||||||
|
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
tile_dict = load_tile_dict()
|
||||||
|
|
||||||
maskgit = GinkaMaskGIT(
|
maskgit = GinkaMaskGIT(
|
||||||
num_classes=NUM_CLASSES,
|
num_classes=NUM_CLASSES,
|
||||||
@ -275,7 +383,7 @@ def train():
|
|||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT)
|
metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT, tile_dict)
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||||
f"E: {epoch + 1} | "
|
f"E: {epoch + 1} | "
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user