diff --git a/data/src/auto.ts b/data/src/auto.ts index 515fe7c..5a055b4 100644 --- a/data/src/auto.ts +++ b/data/src/auto.ts @@ -245,13 +245,13 @@ const labelConfig: IAutoLabelConfig = { commonDoors: [2], specialDoors: [2, 2], keys: [3], - redGems: [4], - blueGems: [5], - greenGems: [6], - potions: [7], - items: [8], - enemies: [9], - entry: 10 + redGems: [3], + blueGems: [3], + greenGems: [3], + potions: [3], + items: [3], + enemies: [4], + entry: 5 }, allowedSize: [[13, 13]], allowUselessBranch: false, @@ -333,17 +333,17 @@ const labelConfig: IAutoLabelConfig = { const data: GinkaTrainData = { map: floor.data.map, size: [width, height], - heatmap: [ - normalizeHeatmap(info.wallHeatmap), - normalizeHeatmap(info.enemyHeatmap), - normalizeHeatmap(info.resourceHeatmap), - normalizeHeatmap(info.potionHeatmap), - normalizeHeatmap(info.gemHeatmap), - normalizeHeatmap(info.keyHeatmap), - normalizeHeatmap(info.itemHeatmap), - normalizeHeatmap(info.entryHeatmap), - normalizeHeatmap(info.doorHeatmap) - ], + // heatmap: [ + // normalizeHeatmap(info.wallHeatmap), + // normalizeHeatmap(info.enemyHeatmap), + // normalizeHeatmap(info.resourceHeatmap), + // normalizeHeatmap(info.potionHeatmap), + // normalizeHeatmap(info.gemHeatmap), + // normalizeHeatmap(info.keyHeatmap), + // normalizeHeatmap(info.itemHeatmap), + // normalizeHeatmap(info.entryHeatmap), + // normalizeHeatmap(info.doorHeatmap) + // ], val: [ info.globalDensity, info.wallDensity, diff --git a/data/src/shared.ts b/data/src/shared.ts index da6ca03..054420a 100644 --- a/data/src/shared.ts +++ b/data/src/shared.ts @@ -1,17 +1,18 @@ // 基本图块定义 +// 新方案 ID:0=空地 1=墙壁 2=门 3=资源(all) 4=怪物 5=入口 6=掩码 export const emptyTiles = new Set([0]); export const wallTiles = new Set([1]); export const decorationTiles = new Set([16]); export const commonDoorTiles = new Set([2]); export const specialDoorTiles = new Set([2]); export const keyTiles = new Set([3]); -export const redGemTiles = new Set([4]); -export const blueGemTiles = new Set([5]); -export const greenGemTiles = new Set([6]); -export const potionTiles = new Set([7]); -export const itemTiles = new Set([8]); -export const enemyTiles = new Set([9]); -export const entryTiles = new Set([10]); +export const redGemTiles = new Set([3]); +export const blueGemTiles = new Set([3]); +export const greenGemTiles = new Set([3]); +export const potionTiles = new Set([3]); +export const itemTiles = new Set([3]); +export const enemyTiles = new Set([4]); +export const entryTiles = new Set([5]); // 组合图块定义 export const doorTiles = commonDoorTiles.union(specialDoorTiles); diff --git a/ginka/dataset.py b/ginka/dataset.py index 9b23ca8..e2d9c12 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -1,7 +1,6 @@ import json import random import torch -import cv2 import numpy as np from torch.utils.data import Dataset @@ -15,231 +14,6 @@ def load_data(path: str): return data_list -# 资源类别压缩:将所有资源 tile(钥匙/红宝石/蓝宝石/绿宝石/血瓶/道具)统一映射为 3 -# 其余 tile 保持原始编号(enemy=9, entry=10, mask=15) -_RESOURCE_REMAP = np.array([0, 1, 2, 3, 3, 3, 3, 3, 3, 9, 10, 11, 12, 13, 14, 15], dtype=np.int64) - -def remap_resources(arr: np.ndarray) -> np.ndarray: - """将地图 numpy 数组中的资源 tile (3~8) 统一压缩为 3。""" - return _RESOURCE_REMAP[arr] - -class GinkaMaskGITDataset(Dataset): - def __init__( - self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6, - noise_prob=0.2, drop_prob=0.2, noise_sigma=0.1 - ): - self.data = load_data(data_path) - self.sigma_rand = sigma_rand - self.blur_min = blur_min - self.blur_max = blur_max - self.noise_prob = noise_prob - self.drop_prob = drop_prob - self.noise_sigma = noise_sigma - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - - target_np = np.array(item['map']) - heatmap = np.array(item['heatmap'], dtype=np.float32) - - # 数据增强 - if np.random.rand() > 0.5: - k = np.random.randint(0, 4) - target_np = np.rot90(target_np, k) - for i in range(0, heatmap.shape[0]): - heatmap[i] = np.rot90(heatmap[i], k) - - if np.random.rand() > 0.5: - target_np = np.fliplr(target_np) - for i in range(0, heatmap.shape[0]): - heatmap[i] = np.fliplr(heatmap[i]) - - if np.random.rand() > 0.5: - target_np = np.flipud(target_np) - for i in range(0, heatmap.shape[0]): - heatmap[i] = np.flipud(heatmap[i]) - - target = torch.LongTensor(target_np.copy()) # [H, W] - cond = torch.FloatTensor(item['val']) # [cond_dim] - - if random.random() < 0.5: - size = random.randint(self.blur_min, self.blur_max) - if size % 2 == 0: - size = size + 1 if random.random() < 0.5 else size - 1 - heatmap = cv2.GaussianBlur(heatmap, (size, size), 0) - else: - sizeX = random.randint(self.blur_min, self.blur_max) - sizeY = random.randint(self.blur_min, self.blur_max) - if sizeX % 2 == 0: - sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1 - if sizeY % 2 == 0: - sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1 - heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0) - - heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W] - - for i in range(0, heatmap.shape[0]): - if np.random.rand() < self.noise_prob: - sigma = random.random() * self.noise_sigma - heatmap[i] = heatmap[i] * sigma + torch.rand_like(heatmap[i]) * (1 - sigma) - elif np.random.rand() < self.drop_prob: - heatmap[i] = torch.zeros_like(heatmap[i]) - - if random.random() < 0.5: - sigma = random.random() * self.sigma_rand - rand = torch.rand_like(heatmap) - heatmap = heatmap * (1 - sigma) + rand * sigma - - return { - "cond": cond, - "target_map": target, - "heatmap": heatmap - } - -class GinkaHeatmapDataset(Dataset): - def __init__(self, data_path: str, min_mask=0, max_mask=0.8, blur_min=3, blur_max=6): - self.data = load_data(data_path) - self.blur_min = blur_min - self.blur_max = blur_max - self.min_mask = min_mask - self.max_mask = max_mask - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - - heatmap = np.array(item['heatmap'], dtype=np.float32) - - # 数据增强 - if np.random.rand() > 0.5: - k = np.random.randint(0, 4) - for i in range(0, heatmap.shape[0]): - heatmap[i] = np.rot90(heatmap[i], k) - - if np.random.rand() > 0.5: - for i in range(0, heatmap.shape[0]): - heatmap[i] = np.fliplr(heatmap[i]) - - if np.random.rand() > 0.5: - for i in range(0, heatmap.shape[0]): - heatmap[i] = np.flipud(heatmap[i]) - - target = heatmap.copy() - - if random.random() < 0.5: - size = random.randint(self.blur_min, self.blur_max) - if size % 2 == 0: - size = size + 1 if random.random() < 0.5 else size - 1 - target = cv2.GaussianBlur(target, (size, size), 0) - else: - sizeX = random.randint(self.blur_min, self.blur_max) - sizeY = random.randint(self.blur_min, self.blur_max) - if sizeX % 2 == 0: - sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1 - if sizeY % 2 == 0: - sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1 - target = cv2.GaussianBlur(target, (sizeX, sizeY), 0) - - target = torch.FloatTensor(target) # [heatmap_channel, H, W] - cond = torch.FloatTensor(heatmap) # [heatmap_channel, H, W] - C, H, W = target.shape - - for i in range(C): - total = H * W - ratio = np.random.random() * (self.max_mask - self.min_mask) + self.min_mask - num = int(total * ratio) - - idx = np.random.choice(total, num, replace=False) - - mask = np.zeros(total, dtype=bool) - mask[idx] = True - mask = mask.reshape(H, W) - cond[i, mask] = 0 - - return { - "target_heatmap": heatmap, - "cond_heatmap": cond - } - - -class GinkaJointDataset(Dataset): - def __init__(self, data_path: str, min_mask=0, max_mask=0.8, blur_min=3, blur_max=6): - self.data = load_data(data_path) - self.blur_min = blur_min - self.blur_max = blur_max - self.min_mask = min_mask - self.max_mask = max_mask - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - - target_map = np.array(item['map']) - heatmap = np.array(item['heatmap'], dtype=np.float32) - - if np.random.rand() > 0.5: - k = np.random.randint(0, 4) - target_map = np.rot90(target_map, k) - for i in range(0, heatmap.shape[0]): - heatmap[i] = np.rot90(heatmap[i], k) - - if np.random.rand() > 0.5: - target_map = np.fliplr(target_map) - for i in range(0, heatmap.shape[0]): - heatmap[i] = np.fliplr(heatmap[i]) - - if np.random.rand() > 0.5: - target_map = np.flipud(target_map) - for i in range(0, heatmap.shape[0]): - heatmap[i] = np.flipud(heatmap[i]) - - target_heatmap = heatmap.copy() - - if random.random() < 0.5: - size = random.randint(self.blur_min, self.blur_max) - if size % 2 == 0: - size = size + 1 if random.random() < 0.5 else size - 1 - target_heatmap = cv2.GaussianBlur(target_heatmap, (size, size), 0) - else: - sizeX = random.randint(self.blur_min, self.blur_max) - sizeY = random.randint(self.blur_min, self.blur_max) - if sizeX % 2 == 0: - sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1 - if sizeY % 2 == 0: - sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1 - target_heatmap = cv2.GaussianBlur(target_heatmap, (sizeX, sizeY), 0) - - target_map = torch.LongTensor(target_map.copy()) - target_heatmap = torch.FloatTensor(target_heatmap) - cond_heatmap = torch.FloatTensor(heatmap.copy()) - channels, height, width = cond_heatmap.shape - - for i in range(channels): - total = height * width - ratio = np.random.random() * (self.max_mask - self.min_mask) + self.min_mask - num = int(total * ratio) - - masked_indices = np.random.choice(total, num, replace=False) - - mask = np.zeros(total, dtype=bool) - mask[masked_indices] = True - mask = mask.reshape(height, width) - cond_heatmap[i, mask] = 0 - - return { - "target_map": target_map, - "target_heatmap": target_heatmap, - "cond_heatmap": cond_heatmap - } - - def _compute_symmetry(target_np: np.ndarray) -> tuple: """从 numpy 地图矩阵中直接计算三种对称性,O(H*W)""" sym_h = bool(np.all(target_np == target_np[:, ::-1])) @@ -254,21 +28,21 @@ class GinkaVQDataset(Dataset): 每次 __getitem__ 按权重随机选取以下四种子集之一: A (standard): 标准 MaskGIT 随机掩码,随机遮盖部分 tile - B (wall-only): 仅保留 wall(1) + floor(0),其余全部替换为 MASK(15) + B (wall-only): 仅保留 wall(1) + floor(0),其余全部替换为 MASK(6) C (wall-random): 在 B 基础上,再随机 mask 部分 wall tile - D (wall+entry): 仅保留 wall(1) + floor(0) + entrance(10),其余全部替换为 MASK(15) + D (wall+entry): 仅保留 wall(1) + floor(0) + entrance(5),其余全部替换为 MASK(6) 返回 dict: raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码) - masked_map: LongTensor [H*W] MaskGIT 输入(被 mask 的位置 = 15) + masked_map: LongTensor [H*W] MaskGIT 输入(被 mask 的位置 = 6) target_map: LongTensor [H*W] CE loss ground truth(等同 raw_map) subset: str 子集标识,供调试/统计用 """ FLOOR = 0 WALL = 1 - ENTRANCE = 10 - MASK_ID = 15 + ENTRANCE = 5 + MASK_ID = 6 def __init__( self, @@ -401,7 +175,7 @@ class GinkaVQDataset(Dataset): subset: 'A' | 'B' | 'C' | 'D' Returns: - [H*W] int64,被遮盖位置值为 MASK_ID(15) + [H*W] int64,被遮盖位置値为 MASK_ID(6) """ H, W = raw.shape @@ -434,7 +208,7 @@ class GinkaVQDataset(Dataset): return flat else: # D - # 仅保留 wall(1) 和 entrance(10),floor(0) 和其他非墙内容全部 mask + # 仅保留 wall(1) 和 entrance(5),floor(0) 和其他非墙内容全部 mask flat = raw.reshape(-1).copy() keep = (flat == self.WALL) | (flat == self.ENTRANCE) flat[~keep] = self.MASK_ID @@ -452,7 +226,6 @@ class GinkaVQDataset(Dataset): item = self.data[idx] raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W] - raw_np = remap_resources(raw_np) # 资源压缩 subset = self._choose_subset() masked_np = self._apply_subset(raw_np, subset) # [H*W] raw_flat = raw_np.reshape(-1) # [H*W] @@ -472,7 +245,7 @@ class GinkaVQDataset(Dataset): return { "raw_map": raw_t, # VQ-VAE 编码器输入 "slice1": make_slice(raw_t, {0, 1}), # 通道 1:floor+wall - "slice2": make_slice(raw_t, {0, 1, 2, 9, 10}),# 通道 2:floor+wall+门+怪+入口 + "slice2": make_slice(raw_t, {0, 1, 2, 4, 5}), # 通道 2:floor+wall+门+怪+入口 "slice3": raw_t.clone(), # 通道 3:完整地图 "masked_map": torch.LongTensor(masked_np), # MaskGIT 输入 "target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth @@ -515,7 +288,7 @@ class GinkaSplitDataset(Dataset): 每个样本只提供完整地图及其三路切片,不做 MaskGIT 掩码处理。 切片按累积式设计: slice1 = floor(0) + wall(1) - slice2 = floor(0) + wall(1) + door(2) + mob(9) + entrance(10) + slice2 = floor(0) + wall(1) + door(2) + mob(4) + entrance(5) slice3 = 完整地图(所有 tile) 返回 dict: @@ -534,7 +307,6 @@ class GinkaSplitDataset(Dataset): def __getitem__(self, idx): item = self.data[idx] arr = np.array(item['map'], dtype=np.int64) # [H, W] - arr = remap_resources(arr) # 资源压缩 # 随机旋转 / 翻转数据增强 if np.random.rand() > 0.5: @@ -549,7 +321,7 @@ class GinkaSplitDataset(Dataset): return { "raw_map": raw, "slice1": make_slice(raw, {0, 1}), - "slice2": make_slice(raw, {0, 1, 2, 9, 10}), + "slice2": make_slice(raw, {0, 1, 2, 4, 5}), "slice3": raw.clone(), } @@ -571,5 +343,5 @@ if __name__ == "__main__": print(f"raw_map shape={raw.shape}, dtype={raw.dtype}") print(f"masked_map shape={masked.shape}, dtype={masked.dtype}") print(f"target_map shape={target.shape}, dtype={target.dtype}") - print(f"被 mask 的位置数: {(masked == 15).sum().item()} / {masked.numel()}") + print(f"被 mask 的位置数: {(masked == 6).sum().item()} / {masked.numel()}") print(f"\n200 次采样子集分布: {subset_count}") diff --git a/ginka/heatmap/cond.py b/ginka/heatmap/cond.py deleted file mode 100644 index 9e71a69..0000000 --- a/ginka/heatmap/cond.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch -import torch.nn as nn - -class HeatmapCond(nn.Module): - def __init__(self, T=100, embed_dim=128, heatmap_dim=8, output_dim=128): - super().__init__() - self.time_embedding = nn.Embedding(T, embed_dim) - self.conv1 = nn.Sequential( - nn.Conv2d(heatmap_dim, output_dim // 4, 3, padding=1, padding_mode='replicate'), - nn.BatchNorm2d(output_dim // 4), - nn.GELU() - ) - self.conv2 = nn.Sequential( - nn.Conv2d(output_dim // 4, output_dim // 2, 3, padding=1, padding_mode='replicate'), - nn.BatchNorm2d(output_dim // 2), - nn.GELU() - ) - self.conv3 = nn.Sequential( - nn.Conv2d(output_dim // 2, output_dim, 3, padding=1, padding_mode='replicate') - ) - - self.fc1 = nn.Sequential( - nn.Linear(embed_dim, output_dim // 4), - nn.Dropout(0.3), - nn.LayerNorm(output_dim // 4), - nn.GELU() - ) - self.fc2 = nn.Sequential( - nn.Linear(embed_dim, output_dim // 2), - nn.Dropout(0.3), - nn.LayerNorm(output_dim // 2), - nn.GELU() - ) - self.fc3 = nn.Sequential( - nn.Linear(embed_dim, output_dim), - nn.Dropout(0.3), - nn.LayerNorm(output_dim), - nn.GELU() - ) - - def forward(self, heatmap: torch.Tensor, t: torch.Tensor): - # heatmap: [B, C, H, W] - # t: [B] - t_embed = self.time_embedding(t) - x = self.conv1(heatmap) + self.fc1(t_embed).unsqueeze(1).unsqueeze(1).permute(0, 3, 1, 2) - x = self.conv2(x) + self.fc2(t_embed).unsqueeze(1).unsqueeze(1).permute(0, 3, 1, 2) - x = self.conv3(x) + self.fc3(t_embed).unsqueeze(1).unsqueeze(1).permute(0, 3, 1, 2) - return x - \ No newline at end of file diff --git a/ginka/heatmap/diffusion.py b/ginka/heatmap/diffusion.py deleted file mode 100644 index 6e3a53e..0000000 --- a/ginka/heatmap/diffusion.py +++ /dev/null @@ -1,63 +0,0 @@ -import math -import torch - -class Diffusion: - def __init__(self, device, T=100, noise_scale=0.5): - self.T = T - self.device = device - self.noise_scale = noise_scale - - # cosine schedule(推荐) - steps = torch.arange(T + 1, dtype=torch.float32) - s = 0.1 - f = torch.cos(((steps / (T + 1)) + s) / (1 + s) * math.pi * 0.5) ** 2 - alpha_bar = f / f[0] - - self.alpha_bar = alpha_bar.to(device) - self.sqrt_ab = torch.sqrt(self.alpha_bar) - self.sqrt_one_minus_ab = torch.sqrt(1 - self.alpha_bar) - - def q_sample(self, x0, t, noise): - """ - 前向加噪:x_t = sqrt(αbar_t) * x0 + sqrt(1-αbar_t) * noise_scale * ε - noise_scale 降低噪声功率,使信号不被淹没 - """ - return ( - self.sqrt_ab[t][:, None, None, None] * x0 - + self.sqrt_one_minus_ab[t][:, None, None, None] * noise * self.noise_scale - ) - - def sample(self, model, cond: torch.Tensor, steps=20): - """ - DDIM 风格逆向采样,模型预测 x_0 - x_{t-1} = sqrt(αbar_{t-1}) * x0_pred - + sqrt(1-αbar_{t-1}) / sqrt(1-αbar_t) * (x_t - sqrt(αbar_t) * x0_pred) - """ - B = cond.shape[0] - # 初始噪声与前向过程保持一致的噪声功率 - x = torch.randn_like(cond).to(cond.device) * self.noise_scale - - step_size = self.T // steps - - for i in reversed(range(0, self.T, step_size)): - t = torch.full((B,), i, device=cond.device) - - # 模型直接预测 x_0 - x0_pred = model(x, cond, t) - - alpha = self.alpha_bar[i] - alpha_prev = self.alpha_bar[max(i - step_size, 0)] - - # DDIM x0-prediction 更新 - direction = ( - torch.sqrt(1 - alpha_prev) / torch.sqrt(1 - alpha) - ) * (x - torch.sqrt(alpha) * x0_pred) - - x = torch.sqrt(alpha_prev) * x0_pred + direction - - return x - -if __name__ == '__main__': - diff = Diffusion("cpu") - print(diff.sqrt_one_minus_ab) - print(diff.sqrt_ab) diff --git a/ginka/heatmap/model.py b/ginka/heatmap/model.py deleted file mode 100644 index e2ee83e..0000000 --- a/ginka/heatmap/model.py +++ /dev/null @@ -1,73 +0,0 @@ -import time -import torch -import torch.nn as nn -from .cond import HeatmapCond -from ..maskGIT.maskGIT import Transformer -from ..utils import print_memory - -class GinkaHeatmapModel(nn.Module): - def __init__( - self, T=100, embed_dim=128, heatmap_dim=8, d_model=128, dim_ff=512, nhead=8, - num_layers=4, map_size=13*13 - ): - super().__init__() - self.heatmap_dim = heatmap_dim - self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model)) - self.cond = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model) - self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model) - self.transformer = Transformer(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers) - self.cross_attn = nn.MultiheadAttention(d_model, num_heads=nhead, batch_first=True) - self.output_fc = nn.Sequential( - nn.Linear(d_model, d_model // 2), - nn.LayerNorm(d_model // 2), - nn.Dropout(0.3), - nn.GELU(), - - nn.Linear(d_model // 2, heatmap_dim) - ) - - def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor): - # input: [B, heatmap_dim, H, W] 噪声 - # cond: [B, heatmap_dim, H, W] 点图 - # t: [B] - input = self.input(input, t) # [B, d_model, H, W] - cond = self.cond(cond, t) # [B, d_model, H, W] - B, C, H, W = input.shape - scale = torch.sigmoid(cond) # [B, d_model, H, W] - hidden = input * (1 + scale) + cond # [B, d_model, H, W] - hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model] - hidden = hidden + self.pos_embedding # [B, H * W, d_model] - hidden = self.transformer(hidden) # [B, H * W, d_model] - cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model] - attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens) # [B, H * W, d_model] - hidden = hidden + attn # [B, H * W, d_model] - output = self.output_fc(hidden) # [B, H * W, heatmap_dim] - return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2) # [B, heatmap_dim, H, W] - -if __name__ == "__main__": - device = torch.device("cpu") - - input = torch.randn(1, 9, 13, 13).to(device) - cond = torch.randint(0, 1, [1, 9, 13, 13]).to(device) - t = torch.randint(0, 100, [1]).to(device) - - # 初始化模型 - model = GinkaHeatmapModel(heatmap_dim=9).to(device) - - print_memory("初始化后") - - # 前向传播 - start = time.perf_counter() - output = model(input, cond.float(), t) - end = time.perf_counter() - - print_memory("前向传播后") - - print(f"推理耗时: {end - start}") - print(f"输出形状: output={output.shape}") - print(f"Tile Embedding parameters: {sum(p.numel() for p in model.cond.parameters())}") - print(f"Condition Encoder parameters: {sum(p.numel() for p in model.input.parameters())}") - print(f"MaskGIT parameters: {sum(p.numel() for p in model.transformer.parameters())}") - print(f"Output parameters: {sum(p.numel() for p in model.output_fc.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") - diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py deleted file mode 100644 index 97c65ac..0000000 --- a/ginka/train_heatmap.py +++ /dev/null @@ -1,280 +0,0 @@ -import argparse -import os -import sys -import math -from datetime import datetime -import torch -import torch.nn.functional as F -import torch.optim as optim -import cv2 -import numpy as np -from perlin_numpy import generate_fractal_noise_2d -from tqdm import tqdm -from torch.utils.data import DataLoader -from .maskGIT.model import GinkaMaskGIT -from .dataset import GinkaHeatmapDataset -from shared.image import matrix_to_image_cv -from .heatmap.model import GinkaHeatmapModel -from .heatmap.diffusion import Diffusion -from .utils import nms_sampling - -# 图块定义: -# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶 -# 8. 道具, 9. 怪物, 10. 入口, 15. 掩码 token - -# 热力图定义 -# 0. 墙壁热力图, 1. 怪物热力图, 2. 资源热力图, 3. 血瓶热力图, 4. 宝石热力图, 5. 钥匙热力图 -# 6. 道具热力图, 7. 入口热力图, 8. 门热力图 - -BATCH_SIZE = 128 -VAL_BATCH_DIVIDER = 64 -NUM_CLASSES = 16 -MASK_TOKEN = 15 -GENERATE_STEP = 8 -MAP_W = 13 -MAP_H = 13 -HEATMAP_CHANNEL = 9 -LABEL_SMOOTHING = 0 -BLUR_MIN_SIZE = 3 -BLUR_MAX_SIZE = 9 -RAND_RATIO = 0.15 -# MaskGIT 生成设置 -USE_MASK_GIT_PREVIEW = True -NUM_LAYERS = 4 -D_MODEL = 192 -# Diffusion 生成设置 -NUM_LAYERS_DIFFUSION = 4 -D_MODEL_DIFFUSION = 128 -T_DIFFUSION = 100 -MIN_MASK = 0 -MAX_MASK = 1 -NOISE_SCALE = 0.3 -W = 5 # CFG 参数 - -device = torch.device( - "cuda:1" if torch.cuda.is_available() - else "mps" if torch.mps.is_available() - else "cpu" -) -os.makedirs("result", exist_ok=True) -os.makedirs("result/heatmap", exist_ok=True) -os.makedirs("result/final_img", exist_ok=True) - -disable_tqdm = not sys.stdout.isatty() - -def parse_arguments(): - parser = argparse.ArgumentParser(description="training codes") - parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--state_ginka", type=str, default="result/heatmap/ginka-100.pth") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=100) - parser.add_argument("--checkpoint", type=int, default=5) - parser.add_argument("--load_optim", type=bool, default=True) - parser.add_argument("--use_maskgit", type=bool, default=True) - parser.add_argument("--maskgit_path", type=str, default="result/ginka_transformer.pth") - args = parser.parse_args() - return args - -def train(): - print(f"Using {device.type} to train model.") - - args = parse_arguments() - - if args.use_maskgit: - maskGIT = GinkaMaskGIT( - num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, - num_layers=NUM_LAYERS, d_model=D_MODEL - ).to(device) - maskGIT.eval() - model = GinkaHeatmapModel( - T=T_DIFFUSION, heatmap_dim=HEATMAP_CHANNEL, d_model=D_MODEL_DIFFUSION, - num_layers=NUM_LAYERS_DIFFUSION - ).to(device) - - diffusion = Diffusion(device, noise_scale=NOISE_SCALE) - - dataset = GinkaHeatmapDataset(args.train, min_mask=MIN_MASK, max_mask=MAX_MASK) - dataset_val = GinkaHeatmapDataset(args.validate, min_mask=MIN_MASK, max_mask=MAX_MASK) - dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) - - optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) - - # 用于生成图片 - tile_dict = dict() - for file in os.listdir('tiles'): - name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) - - # 接续训练 - if args.resume: - data_ginka = torch.load(args.state_ginka, map_location=device) - - model.load_state_dict(data_ginka["model_state"], strict=False) - - if args.load_optim: - if data_ginka.get("optim_state") is not None: - optimizer.load_state_dict(data_ginka["optim_state"]) - - print("Train from loaded state.") - - if args.use_maskgit: - data_maskGIT = torch.load(args.maskgit_path, map_location=device) - maskGIT.load_state_dict(data_maskGIT["model_state"]) - print("Loaded MaskGIT model state.") - - for epoch in tqdm(range(args.epochs), desc="Diffusion Training", disable=disable_tqdm): - loss_total = torch.Tensor([0]).to(device) - - for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) - B, C, H, W = target_heatmap.shape - - optimizer.zero_grad() - - t = torch.randint(1, T_DIFFUSION, [B], device=device) - noise = torch.randn_like(target_heatmap) - - x_t = diffusion.q_sample(target_heatmap, t, noise) - - # CFG 随机概率没有输入条件 - if np.random.rand() < 0.2: - cond_heatmap = torch.zeros_like(cond_heatmap) - - # 模型预测 x_0,损失直接对齐热力图 - pred_x0 = model(x_t, cond_heatmap, t) - - loss = F.mse_loss(pred_x0, target_heatmap) - - loss.backward() - optimizer.step() - loss_total += loss.detach() - - scheduler.step() - - avg_loss = loss_total.item() / len(dataloader) - tqdm.write( - f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + - f"E: {epoch + 1} | Loss: {avg_loss:.6f} | " + - f"LR: {scheduler.get_last_lr()[0]:.6f}" - ) - - # 每若干轮输出一次图片,并保存检查点 - if (epoch + 1) % args.checkpoint == 0: - # 保存检查点 - torch.save({ - "model_state": model.state_dict(), - "optim_state": optimizer.state_dict(), - }, f"result/heatmap/ginka-{epoch + 1}.pth") - - val_loss_total = torch.Tensor([0]).to(device) - model.eval() - with torch.no_grad(): - idx = 0 - for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): - # 1. 验证集验证 - cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) - B, C, H, W = target_heatmap.shape - - t = torch.randint(1, T_DIFFUSION, [B], device=device) - noise = torch.randn_like(target_heatmap) - - x_t = diffusion.q_sample(target_heatmap, t, noise) - - pred_x0 = model(x_t, cond_heatmap, t) - - loss = F.mse_loss(pred_x0, target_heatmap) - - val_loss_total += loss.detach() - - # 2. 从头完整生成,并使用训练好的 MaskGIT 生成地图 - if args.use_maskgit: - map = full_generate(model, maskGIT, cond_heatmap, diffusion) - - generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict) - cv2.imwrite(f"result/final_img/{idx}.png", generated_img) - - idx += 1 - - # 3. 完全随机生成五张图 - if args.use_maskgit: - for i in range(0, 5): - ar = np.ndarray([1, HEATMAP_CHANNEL, MAP_H, MAP_W]) - k = get_nms_sampling_count() - for c in range(0, HEATMAP_CHANNEL): - noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H,0:MAP_W] - ar[0,c] = nms_sampling(noise, k[c]) - - map = full_generate(model, maskGIT, torch.FloatTensor(ar).to(device), diffusion) - generated_img = matrix_to_image_cv(map.view(1, H, W)[0].cpu().numpy(), tile_dict) - cv2.imwrite(f"result/final_img/g-{i}.png", generated_img) - - avg_loss_val = val_loss_total.item() / len(dataloader_val) - tqdm.write( - f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " + - f"Loss: {avg_loss_val:.6f}" - ) - - print("Train ended.") - torch.save({ - "model_state": model.state_dict(), - }, f"result/ginka_heatmap.pth") - -def get_nms_sampling_count(): - return [ - np.random.randint(20, 40), - np.random.randint(10, 20), - np.random.randint(10, 30), - np.random.randint(4, 12), - np.random.randint(4, 12), - np.random.randint(2, 6), - np.random.randint(0, 2), - np.random.randint(1, 3), - np.random.randint(2, 10) - ] - -def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion): - fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap) - fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap)) - fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) # [B, C, H, W] - return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap) - -def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor): - # heatmap: [B, C, H, W] - map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device) - for i in range(GENERATE_STEP): - # 1. 预测 - logits = maskGIT(map, heatmap) # [1, H * W, num_classes] - probs = F.softmax(logits, dim=-1) - - # 2. 采样(为了多样性,这里可以使用概率采样而不是取最大值) - dist = torch.distributions.Categorical(probs) - sampled_tiles = dist.sample() # [1, H * W] - - # 3. 计算置信度 (模型对采样结果的信心程度) - confidences = torch.gather(probs, -1, sampled_tiles.unsqueeze(-1)).squeeze(-1) - - # 4. 决定本轮要固定多少个格子 (上凸函数逻辑) - ratio = math.cos(((i + 1) / GENERATE_STEP) * math.pi / 2) - num_to_mask = math.floor(ratio * MAP_H * MAP_W) - - # 5. 更新画布:保留置信度最高的部分,其余位置设回 MASK - # 注意:这里逻辑上通常是保留当前步预测中置信度最高的,并结合已有的非 mask 部分 - if num_to_mask > 0: - _, mask_indices = torch.topk(confidences, k=num_to_mask, largest=False) - sampled_tiles = sampled_tiles.scatter(1, mask_indices, MASK_TOKEN) - - map = sampled_tiles - if (map == MASK_TOKEN).sum() == 0: - break - - return map - - -if __name__ == "__main__": - torch.set_num_threads(4) - train() diff --git a/ginka/train_joint.py b/ginka/train_joint.py deleted file mode 100644 index 96e29c7..0000000 --- a/ginka/train_joint.py +++ /dev/null @@ -1,397 +0,0 @@ -import argparse -import math -import os -import sys -from datetime import datetime - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F -import torch.optim as optim -from perlin_numpy import generate_fractal_noise_2d -from torch.utils.data import DataLoader -from tqdm import tqdm - -from .dataset import GinkaJointDataset -from .heatmap.diffusion import Diffusion -from .heatmap.model import GinkaHeatmapModel -from .maskGIT.model import GinkaMaskGIT -from .utils import nms_sampling -from shared.image import matrix_to_image_cv - - -# 地图与 token 基础配置 -NUM_CLASSES = 16 -MASK_TOKEN = 15 -MAP_W = 13 -MAP_H = 13 -HEATMAP_CHANNEL = 9 -GENERATE_STEP = 8 - -# 训练批次与损失配置 -BATCH_SIZE = 64 -VAL_BATCH_DIVIDER = 64 -LABEL_SMOOTHING = 0 -CE_WEIGHT = 0.5 # 联合训练里 MaskGIT 监督项的权重 -DROP_RATE = 0.2 # CFG 训练时随机丢弃条件热力图的概率 - -# MaskGIT 模型结构 -NUM_LAYERS = 4 -D_MODEL = 192 - -# Diffusion 模型结构与噪声过程 -NUM_LAYERS_DIFFUSION = 4 -D_MODEL_DIFFUSION = 128 -T_DIFFUSION = 100 -MIN_MASK = 0 -MAX_MASK = 1 -NOISE_SCALE = 0.3 - -# 验证预览配置 -PREVIEW_CFG_WEIGHT = 5 # 预览生成时使用的 CFG 强度 -RANDOM_PREVIEW_COUNT = 5 # 每次验证额外生成的随机预览数量 - -device = torch.device( - "cuda:1" if torch.cuda.is_available() - else "mps" if torch.mps.is_available() - else "cpu" -) -os.makedirs("result", exist_ok=True) -os.makedirs("result/joint", exist_ok=True) -os.makedirs("result/joint_img", exist_ok=True) - -disable_tqdm = not sys.stdout.isatty() - - -def parse_arguments(): - # 解析联合训练脚本的命令行参数。 - parser = argparse.ArgumentParser(description="joint training codes") - parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=50) - parser.add_argument("--checkpoint", type=int, default=5) - parser.add_argument("--load_optim", type=bool, default=True) - parser.add_argument("--maskgit_path", type=str, default="result/ginka_transformer.pth") - args = parser.parse_args() - return args - - -def load_heatmap_checkpoint(model, optimizer, args): - # 加载预训练 Diffusion 权重,并在需要时恢复优化器状态。 - if not args.state_heatmap: - return - - if not os.path.exists(args.state_heatmap): - raise FileNotFoundError(f"Heatmap checkpoint not found: {args.state_heatmap}") - - checkpoint = torch.load(args.state_heatmap, map_location=device) - model.load_state_dict(checkpoint["model_state"], strict=False) - - if args.resume and args.load_optim and checkpoint.get("optim_state") is not None: - optimizer.load_state_dict(checkpoint["optim_state"]) - - print("Loaded Diffusion model state.") - - -def freeze_module(module: torch.nn.Module): - # 冻结模块参数,使其在联合训练中只作为固定监督器使用。 - module.eval() - for parameter in module.parameters(): - parameter.requires_grad = False - - -def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor): - # 用冻结的 MaskGIT 对 Diffusion 生成的热力图施加地图级监督。 - batch_size, height, width = target_map.shape - target_tokens = target_map.view(batch_size, height * width) - canvas = torch.full_like(target_tokens, MASK_TOKEN) - losses = [] - - for step in range(GENERATE_STEP): - current_mask = canvas == MASK_TOKEN - if current_mask.sum().item() == 0: - break - - # 保证前向传播可导 - logits = maskgit(canvas, generated_heatmap) - ce = F.cross_entropy( - logits.permute(0, 2, 1), - target_tokens, - label_smoothing=LABEL_SMOOTHING - ) - losses.append(ce) - - with torch.no_grad(): - probs = F.softmax(logits, dim=-1) - sampled_tiles = torch.argmax(probs, dim=-1) - confidences = torch.gather(probs, -1, sampled_tiles.unsqueeze(-1)).squeeze(-1) - - ratio = math.cos(((step + 1) / GENERATE_STEP) * math.pi / 2) - num_to_mask = math.floor(ratio * target_tokens.shape[1]) - - if num_to_mask > 0: - _, mask_indices = torch.topk(confidences, k=num_to_mask, largest=False) - sampled_tiles = sampled_tiles.scatter(1, mask_indices, MASK_TOKEN) - - canvas = sampled_tiles - - if not losses: - return torch.zeros((), device=generated_heatmap.device) - - return torch.stack(losses).mean() - - -def load_tile_dict(): - # 加载用于可视化地图的图块贴图。 - tile_dict = dict() - for file in os.listdir('tiles'): - name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) - return tile_dict - - -def get_nms_sampling_count(): - # 为随机点图预览采样每个通道的点数量。 - return [ - np.random.randint(20, 40), - np.random.randint(10, 20), - np.random.randint(10, 30), - np.random.randint(4, 12), - np.random.randint(4, 12), - np.random.randint(2, 6), - np.random.randint(0, 2), - np.random.randint(1, 3), - np.random.randint(2, 10) - ] - - -def maskgit_generate(maskgit, batch_size: int, heatmap: torch.Tensor): - # 使用冻结的 MaskGIT 把热力图解码为完整地图。 - generated_map = torch.full((batch_size, MAP_H * MAP_W), MASK_TOKEN, device=device) - for step in range(GENERATE_STEP): - logits = maskgit(generated_map, heatmap) - probs = F.softmax(logits, dim=-1) - - dist = torch.distributions.Categorical(probs) - sampled_tiles = dist.sample() - confidences = torch.gather(probs, -1, sampled_tiles.unsqueeze(-1)).squeeze(-1) - - ratio = math.cos(((step + 1) / GENERATE_STEP) * math.pi / 2) - num_to_mask = math.floor(ratio * MAP_H * MAP_W) - - if num_to_mask > 0: - _, mask_indices = torch.topk(confidences, k=num_to_mask, largest=False) - sampled_tiles = sampled_tiles.scatter(1, mask_indices, MASK_TOKEN) - - generated_map = sampled_tiles - if (generated_map == MASK_TOKEN).sum() == 0: - break - - return generated_map - - -def full_generate(heatmap_model, maskgit, cond_heatmap: torch.Tensor, diffusion: Diffusion): - # 执行完整预览生成流程:点图 -> 热力图 -> 地图。 - fake_heatmap_cond = diffusion.sample(heatmap_model, cond_heatmap) - fake_heatmap_uncond = diffusion.sample(heatmap_model, torch.zeros_like(cond_heatmap)) - fake_heatmap = fake_heatmap_uncond + PREVIEW_CFG_WEIGHT * (fake_heatmap_uncond - fake_heatmap_cond) - return maskgit_generate(maskgit, cond_heatmap.shape[0], fake_heatmap) - - -def save_random_previews(model, maskgit, diffusion, tile_dict): - # 额外生成随机点图预览,便于观察模型的开放式生成效果。 - for preview_idx in range(RANDOM_PREVIEW_COUNT): - cond_array = np.ndarray([1, HEATMAP_CHANNEL, MAP_H, MAP_W]) - sampling_count = get_nms_sampling_count() - for channel in range(HEATMAP_CHANNEL): - noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H, 0:MAP_W] - cond_array[0, channel] = nms_sampling(noise, sampling_count[channel]) - - generated_map = full_generate(model, maskgit, torch.FloatTensor(cond_array).to(device), diffusion) - generated_img = matrix_to_image_cv(generated_map.view(1, MAP_H, MAP_W)[0].cpu().numpy(), tile_dict) - cv2.imwrite(f"result/joint_img/g-{preview_idx}.png", generated_img) - - -def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict): - # 执行数值验证,并保存生成地图预览图。 - model.eval() - total_loss = 0.0 - total_diffusion_loss = 0.0 - total_maskgit_loss = 0.0 - - with torch.no_grad(): - preview_idx = 0 - for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm): - cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) - target_map = batch["target_map"].to(device) - batch_size, _, map_height, map_width = target_heatmap.shape - - t = torch.randint(1, T_DIFFUSION, [batch_size], device=device) - noise = torch.randn_like(target_heatmap) - x_t = diffusion.q_sample(target_heatmap, t, noise) - - pred_x0 = model(x_t, cond_heatmap, t) - diffusion_loss = F.mse_loss(pred_x0, target_heatmap) - - maskgit_loss = maskgit_joint_loss(maskgit, pred_x0, target_map) - - loss = diffusion_loss + ce_weight * maskgit_loss - total_loss += loss.item() - total_diffusion_loss += diffusion_loss.item() - total_maskgit_loss += maskgit_loss.item() - - # 预览生成结果 - generated_map = full_generate(model, maskgit, cond_heatmap, diffusion) - generated_img = matrix_to_image_cv( - generated_map.view(batch_size, map_height, map_width)[0].cpu().numpy(), - tile_dict, - ) - cv2.imwrite(f"result/joint_img/{preview_idx}.png", generated_img) - preview_idx += 1 - - save_random_previews(model, maskgit, diffusion, tile_dict) - - size = max(len(dataloader), 1) - return { - "loss": total_loss / size, - "diffusion_loss": total_diffusion_loss / size, - "maskgit_loss": total_maskgit_loss / size, - } - - -def train(): - # 联合训练 Diffusion,使其同时受到噪声重建和冻结 MaskGIT 的监督。 - print(f"Using {device.type} to train model.") - - args = parse_arguments() - tile_dict = load_tile_dict() - - maskgit = GinkaMaskGIT( - num_classes=NUM_CLASSES, - heatmap_channel=HEATMAP_CHANNEL, - num_layers=NUM_LAYERS, - d_model=D_MODEL, - ).to(device) - if not os.path.exists(args.maskgit_path): - raise FileNotFoundError(f"MaskGIT checkpoint not found: {args.maskgit_path}") - maskgit_state = torch.load(args.maskgit_path, map_location=device) - maskgit.load_state_dict(maskgit_state["model_state"]) - freeze_module(maskgit) - print("Loaded and froze MaskGIT model state.") - - model = GinkaHeatmapModel( - T=T_DIFFUSION, - heatmap_dim=HEATMAP_CHANNEL, - d_model=D_MODEL_DIFFUSION, - num_layers=NUM_LAYERS_DIFFUSION, - ).to(device) - diffusion = Diffusion(device, T=T_DIFFUSION, noise_scale=NOISE_SCALE) - - dataset = GinkaJointDataset(args.train, min_mask=MIN_MASK, max_mask=MAX_MASK) - dataset_val = GinkaJointDataset(args.validate, min_mask=MIN_MASK, max_mask=MAX_MASK) - dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader( - dataset_val, - batch_size=max(1, BATCH_SIZE // VAL_BATCH_DIVIDER), - shuffle=True, - ) - - optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, - T_max=args.epochs, - eta_min=1e-6, - ) - - load_heatmap_checkpoint(model, optimizer, args) - - for epoch in tqdm(range(args.epochs), desc="Joint Training", disable=disable_tqdm): - model.train() - epoch_loss = 0.0 - epoch_diffusion_loss = 0.0 - epoch_maskgit_loss = 0.0 - - for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) - target_map = batch["target_map"].to(device) - batch_size = target_heatmap.shape[0] - - optimizer.zero_grad() - - t = torch.randint(1, T_DIFFUSION, [batch_size], device=device) - noise = torch.randn_like(target_heatmap) - x_t = diffusion.q_sample(target_heatmap, t, noise) - - cond_for_diffusion = cond_heatmap - use_unconditional_branch = False - if np.random.rand() < DROP_RATE: - cond_for_diffusion = torch.zeros_like(cond_heatmap) - use_unconditional_branch = True - - pred_x0 = model(x_t, cond_for_diffusion, t) - diffusion_loss = F.mse_loss(pred_x0, target_heatmap) - - # 若使用无条件分支,重新对有条件输入预测以计算联合损失 - pred_x0_for_joint = pred_x0 - if use_unconditional_branch: - pred_x0_for_joint = model(x_t, cond_heatmap, t) - - maskgit_loss = maskgit_joint_loss(maskgit, pred_x0_for_joint, target_map) - - loss = diffusion_loss + CE_WEIGHT * maskgit_loss - loss.backward() - optimizer.step() - - epoch_loss += loss.item() - epoch_diffusion_loss += diffusion_loss.item() - epoch_maskgit_loss += maskgit_loss.item() - - scheduler.step() - - train_size = max(len(dataloader), 1) - tqdm.write( - f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"E: {epoch + 1} | " - f"Loss: {epoch_loss / train_size:.6f} | " - f"Diffusion: {epoch_diffusion_loss / train_size:.6f} | " - f"MaskGIT: {epoch_maskgit_loss / train_size:.6f} | " - f"LR: {scheduler.get_last_lr()[0]:.6f}" - ) - - if (epoch + 1) % args.checkpoint == 0: - checkpoint_path = f"result/joint/ginka-joint-{epoch + 1}.pth" - torch.save( - { - "model_state": model.state_dict(), - "optim_state": optimizer.state_dict(), - }, - checkpoint_path, - ) - - metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT, tile_dict) - tqdm.write( - f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"E: {epoch + 1} | " - f"Loss: {metrics['loss']:.6f} | " - f"Diffusion: {metrics['diffusion_loss']:.6f} | " - f"MaskGIT: {metrics['maskgit_loss']:.6f}" - ) - - print("Train ended.") - torch.save( - { - "model_state": model.state_dict(), - "optim_state": optimizer.state_dict(), - }, - "result/ginka_joint_heatmap.pth", - ) - - -if __name__ == "__main__": - torch.set_num_threads(4) - train() \ No newline at end of file diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py deleted file mode 100644 index 7cbbe5b..0000000 --- a/ginka/train_maskGIT.py +++ /dev/null @@ -1,244 +0,0 @@ -import argparse -import os -import sys -import random -import math -from datetime import datetime -import torch -import torch.nn.functional as F -import torch.optim as optim -import cv2 -import numpy as np -from tqdm import tqdm -from torch.utils.data import DataLoader -from .maskGIT.model import GinkaMaskGIT -from .dataset import GinkaMaskGITDataset -from shared.image import matrix_to_image_cv -from .maskGIT.mask import MapMask - -# 标量值定义: -# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块 -# 1. 墙体密度,墙壁/地图面积 -# 2. 门密度,门数量/地图面积 -# 3. 怪物密度,怪物数量/地图面积 -# 4. 资源密度,资源数量/地图面积 -# 5. 宝石密度,宝石数量/地图面积 -# 6. 血瓶密度,血瓶数量/地图面积 -# 7. 钥匙密度,钥匙数量/地图面积 -# 8. 道具密度,道具数量/地图面积 -# 9. 入口数量 - -# 图块定义: -# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶 -# 8. 道具, 9. 怪物, 10. 入口, 15. 掩码 token - -# 热力图定义 -# 0. 墙壁热力图, 1. 怪物热力图, 2. 资源热力图, 3. 血瓶热力图, 4. 宝石热力图, 5. 钥匙热力图 -# 6. 道具热力图, 7. 入口热力图, 8. 门热力图 - -BATCH_SIZE = 128 -VAL_BATCH_DIVIDER = 64 -NUM_CLASSES = 16 -MASK_TOKEN = 15 -GENERATE_STEP = 8 -MAP_SIZE = 13 * 13 -HEATMAP_CHANNEL = 9 -LABEL_SMOOTHING = 0 -BLUR_MIN_SIZE = 3 -BLUR_MAX_SIZE = 9 -RAND_RATIO = 0.3 -MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 -NUM_LAYERS = 4 -D_MODEL = 192 - -device = torch.device( - "cuda:1" if torch.cuda.is_available() - else "mps" if torch.mps.is_available() - else "cpu" -) -os.makedirs("result", exist_ok=True) -os.makedirs("result/transformer", exist_ok=True) -os.makedirs("result/transformer_img", exist_ok=True) - -disable_tqdm = not sys.stdout.isatty() - -def parse_arguments(): - parser = argparse.ArgumentParser(description="training codes") - parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--state_ginka", type=str, default="result/transformer/ginka-100.pth") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=100) - parser.add_argument("--checkpoint", type=int, default=5) - parser.add_argument("--load_optim", type=bool, default=True) - args = parser.parse_args() - return args - -def train(): - print(f"Using {device.type} to train model.") - - args = parse_arguments() - - model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=NUM_LAYERS, d_model=D_MODEL).to(device) - masker = MapMask([0.5, 0.5]) - - dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE) - dataset_val = GinkaMaskGITDataset(args.validate, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE) - dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) - - optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) - - # 用于生成图片 - tile_dict = dict() - for file in os.listdir('tiles'): - name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) - - # 接续训练 - if args.resume: - data_ginka = torch.load(args.state_ginka, map_location=device) - - model.load_state_dict(data_ginka["model_state"], strict=False) - - if args.load_optim: - if data_ginka.get("optim_state") is not None: - optimizer.load_state_dict(data_ginka["optim_state"]) - - print("Train from loaded state.") - - for epoch in tqdm(range(args.epochs), desc="MaskGIT Training", disable=disable_tqdm): - loss_total = torch.Tensor([0]).to(device) - - for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - target_map = batch["target_map"].to(device) - heatmap = batch["heatmap"].to(device) - B, H, W = target_map.shape - - target_map = target_map.view(B, H * W) - - mask = np.zeros((B, H * W)) - for i in range(B): - mask[i] = masker.mask(H, W) - - mask = torch.from_numpy(mask).to(torch.bool).to(device) - - # 掩码 - masked_input = target_map.clone() - masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记 - - logits = model(masked_input, heatmap) - - loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=LABEL_SMOOTHING) - loss = (loss * mask).sum() / (mask.sum() + 1e-6) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - loss_total += loss.detach() - - scheduler.step() - - avg_loss = loss_total.item() / len(dataloader) - tqdm.write( - f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + - f"E: {epoch + 1} | Loss: {avg_loss:.6f} | " + - f"LR: {scheduler.get_last_lr()[0]:.6f}" - ) - - # 每若干轮输出一次图片,并保存检查点 - if (epoch + 1) % args.checkpoint == 0: - # 保存检查点 - torch.save({ - "model_state": model.state_dict(), - "optim_state": optimizer.state_dict(), - }, f"result/transformer/ginka-{epoch + 1}.pth") - - val_loss_total = torch.Tensor([0]).to(device) - model.eval() - with torch.no_grad(): - idx = 0 - gap = 5 - color = (255, 255, 255) # 白色 - vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 - for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): - # 1. 常规生成 - target_map = batch["target_map"].to(device) - heatmap = batch["heatmap"].to(device) - B, H, W = target_map.shape - target_map = target_map.view(B, H * W) - - mask = np.zeros((B, H * W)) - for i in range(B): - mask[i] = masker.mask(H, W) - - mask = torch.from_numpy(mask).to(torch.bool).to(device) - - # 2. 生成掩码矩阵 - masked_input = target_map.clone() - masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记 - - logits = model(masked_input, heatmap) - - loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=LABEL_SMOOTHING) - loss = (loss * mask).sum() / (mask.sum() + 1e-6) - - val_loss_total += loss.detach() - - fake_map = torch.argmax(logits, dim=2).view(B, H, W).cpu().numpy() - fake_img = matrix_to_image_cv(fake_map[0], tile_dict) - real_map = target_map.view(B, H, W).cpu().numpy() - real_img = matrix_to_image_cv(real_map[0], tile_dict) - img = np.block([[real_img], [vline], [fake_img]]) - cv2.imwrite(f"result/transformer_img/{idx}.png", img) - - idx += 1 - - # 2. 从头完整生成 - map = torch.full((B, MAP_SIZE), MASK_TOKEN).to(device) - for i in range(GENERATE_STEP): - # 1. 预测 - logits = model(map, heatmap) # [1, H * W, num_classes] - probs = F.softmax(logits, dim=-1) - - # 2. 采样(为了多样性,这里可以使用概率采样而不是取最大值) - dist = torch.distributions.Categorical(probs) - sampled_tiles = dist.sample() # [1, H * W] - - # 3. 计算置信度 (模型对采样结果的信心程度) - confidences = torch.gather(probs, -1, sampled_tiles.unsqueeze(-1)).squeeze(-1) - - # 4. 决定本轮要固定多少个格子 (上凸函数逻辑) - ratio = math.cos(((i + 1) / GENERATE_STEP) * math.pi / 2) - num_to_mask = math.floor(ratio * MAP_SIZE) - - # 5. 更新画布:保留置信度最高的部分,其余位置设回 MASK - # 注意:这里逻辑上通常是保留当前步预测中置信度最高的,并结合已有的非 mask 部分 - if num_to_mask > 0: - _, mask_indices = torch.topk(confidences, k=num_to_mask, largest=False) - sampled_tiles = sampled_tiles.scatter(1, mask_indices, MASK_TOKEN) - - map = sampled_tiles - if (map == MASK_TOKEN).sum() == 0: - break - - generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict) - img = np.block([[real_img], [vline], [generated_img]]) - cv2.imwrite(f"result/transformer_img/g-{idx}.png", img) - - avg_loss_val = val_loss_total.item() / len(dataloader_val) - tqdm.write( - f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " + - f"Loss: {avg_loss_val:.6f}" - ) - - print("Train ended.") - torch.save({ - "model_state": model.state_dict(), - }, f"result/ginka_transformer.pth") - - -if __name__ == "__main__": - torch.set_num_threads(4) - train() diff --git a/ginka/train_pretrain.py b/ginka/train_pretrain.py deleted file mode 100644 index ab2df99..0000000 --- a/ginka/train_pretrain.py +++ /dev/null @@ -1,341 +0,0 @@ -""" -VQ 编码器预训练脚本(方案 D) - -目标:在联合训练开始前,先单独预训练 VQ 编码器,使其学到地图的大致语义分类。 -解码头(VQDecodeHead)仅在预训练阶段使用,结束后丢弃,权重不迁移到联合训练。 - -训练流程(对应设计文档方案 D 三阶段): - 阶段 0(本脚本):编码器 + 临时解码头,全图重建目标 - 阶段 1(在 train_vq.py 中):编码器冻结 + MaskGIT 热身,启用 --freeze_vq - 阶段 2(在 train_vq.py 中):完整联合训练,编码器用较小 LR - -用法示例: - python -m ginka.train_pretrain - python -m ginka.train_pretrain --resume True --state result/pretrain/pretrain-20.pth - # 预训练完成后,传入权重路径启动联合训练阶段 1: - python -m ginka.train_vq --resume True --state result/pretrain/pretrain_final.pth -""" - -import argparse -import os -import sys -from datetime import datetime - -import numpy as np -import torch -import torch.nn.functional as F -import torch.optim as optim -from torch.utils.data import DataLoader, Dataset -from tqdm import tqdm - -from .vqvae.model import GinkaVQVAE, VQDecodeHead -from .dataset import load_data - -# --------------------------------------------------------------------------- -# 超参数(须与 train_vq.py 中 VQ-VAE 配置保持一致) -# --------------------------------------------------------------------------- -BATCH_SIZE = 64 -NUM_CLASSES = 16 -MAP_SIZE = 13 * 13 -MAP_H = MAP_W = 13 - -# VQ-VAE 超参(保持与 train_vq.py 一致) -VQ_L = 2 -VQ_K = 8 -VQ_D_Z = 128 -VQ_D_MODEL= 192 -VQ_NHEAD = 8 -VQ_LAYERS = 4 -VQ_DIM_FF = 512 -VQ_BETA = 0.5 -VQ_GAMMA = 0.0 - -# Focal Loss -FOCAL_GAMMA = 2.0 # focal loss 聚焦参数(越大越关注难例/稀有类别) - -# 解码头超参(与编码器对称:同等层数和 FFN 宽度) -DH_NHEAD = 8 # Cross-Attention 头数(VQ_D_Z=128 可被 8 整除) -DH_DIM_FF = 512 # FFN 隐层维度(与编码器 VQ_DIM_FF 一致) -DH_LAYERS = 4 # 解码层数(与编码器 VQ_LAYERS 一致) - -# --------------------------------------------------------------------------- -# 设备 -# --------------------------------------------------------------------------- -device = torch.device( - "cuda:1" if torch.cuda.is_available() - else "mps" if torch.backends.mps.is_available() - else "cpu" -) - -os.makedirs("result/pretrain", exist_ok=True) - -disable_tqdm = not sys.stdout.isatty() - -# --------------------------------------------------------------------------- -# Focal Loss -# --------------------------------------------------------------------------- -def focal_loss( - logits: torch.Tensor, - targets: torch.Tensor, - gamma: float = FOCAL_GAMMA, -) -> torch.Tensor: - """ - 多分类 Focal Loss(mean 归约):FL = -(1 - p_t)^gamma * log(p_t) - - 相比 CE,对已被正确分类的高置信度样本施加更小的权重, - 迫使模型关注难分类的稀有 tile(门/怪/资源等)。 - """ - ce = F.cross_entropy(logits, targets, reduction='none') - pt = torch.exp(-ce) - return ((1.0 - pt) ** gamma * ce).mean() - -# --------------------------------------------------------------------------- -# 简单数据集:仅返回 raw_map,无子集划分,无掩码 -# --------------------------------------------------------------------------- -class GinkaPretrainDataset(Dataset): - """ - 预训练专用数据集,仅提供完整原始地图(raw_map)和随机数据增强。 - - 不做子集划分与掩码处理;重建目标为全图所有 169 个位置。 - """ - - def __init__(self, data_path: str): - self.data = load_data(data_path) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - arr = np.array(item['map'], dtype=np.int64) # [H, W] - - # 随机旋转 / 翻转数据增强 - if np.random.rand() > 0.5: - k = np.random.randint(1, 4) - arr = np.rot90(arr, k).copy() - if np.random.rand() > 0.5: - arr = np.fliplr(arr).copy() - if np.random.rand() > 0.5: - arr = np.flipud(arr).copy() - - raw_map = torch.tensor(arr.reshape(-1), dtype=torch.long) # [H*W] - return raw_map - -# --------------------------------------------------------------------------- -# 参数解析 -# --------------------------------------------------------------------------- -def parse_arguments(): - parser = argparse.ArgumentParser(description="VQ 编码器预训练(方案 D)") - parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--state", type=str, default="result/pretrain/pretrain-20.pth", - help="续训时加载的检查点路径") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=50) - parser.add_argument("--checkpoint", type=int, default=5, - help="每隔多少 epoch 保存检查点并输出验证指标") - parser.add_argument("--load_optim", type=bool, default=True) - return parser.parse_args() - -# --------------------------------------------------------------------------- -# 验证:计算全图 top-1 准确率及关键类别(墙壁)召回率 -# --------------------------------------------------------------------------- -@torch.no_grad() -def validate( - model_vq: GinkaVQVAE, - decode_head: VQDecodeHead, - dataloader_val: DataLoader, -) -> dict: - model_vq.eval() - decode_head.eval() - - total, correct = 0, 0 - wall_tp, wall_gt = 0, 0 # wall tile=1 的召回 - class_correct = torch.zeros(NUM_CLASSES, dtype=torch.long) - class_total = torch.zeros(NUM_CLASSES, dtype=torch.long) - - for raw_map in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): - raw_map = raw_map.to(device) # [B, H*W] - - z_q, _, _, _, _, _ = model_vq(raw_map) - logits = decode_head(z_q) # [B, H*W, C] - pred = logits.argmax(dim=-1) # [B, H*W] - - correct += (pred == raw_map).sum().item() - total += raw_map.numel() - - # 墙壁召回 - wall_mask = (raw_map == 1) - wall_tp += (pred[wall_mask] == 1).sum().item() - wall_gt += wall_mask.sum().item() - - # 逐类别统计 - for c in range(NUM_CLASSES): - mask_c = (raw_map == c) - class_correct[c] += (pred[mask_c] == c).sum().item() - class_total[c] += mask_c.sum().item() - - acc = correct / max(total, 1) - wall_rec = wall_tp / max(wall_gt, 1) - - # 有样本的类别逐一统计 - per_class = {} - for c in range(NUM_CLASSES): - if class_total[c] > 0: - per_class[c] = class_correct[c].item() / class_total[c].item() - - return {"acc": acc, "wall_recall": wall_rec, "per_class": per_class} - -# --------------------------------------------------------------------------- -# 主训练函数 -# --------------------------------------------------------------------------- -def train(): - print(f"Using device: {device}") - args = parse_arguments() - - # ---- 模型 ---- - model_vq = GinkaVQVAE( - num_classes=NUM_CLASSES, - L=VQ_L, K=VQ_K, d_z=VQ_D_Z, - d_model=VQ_D_MODEL, nhead=VQ_NHEAD, - num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, - map_size=MAP_SIZE, - beta=VQ_BETA, gamma=VQ_GAMMA, - ).to(device) - - decode_head = VQDecodeHead( - num_classes=NUM_CLASSES, - d_z=VQ_D_Z, - map_size=MAP_SIZE, - nhead=DH_NHEAD, - dim_ff=DH_DIM_FF, - num_layers=DH_LAYERS, - ).to(device) - - vq_params = sum(p.numel() for p in model_vq.parameters()) - dh_params = sum(p.numel() for p in decode_head.parameters()) - print(f"VQ-VAE 参数量: {vq_params:,} ({vq_params/1e6:.3f}M)") - print(f"DecodeHead 参数量: {dh_params:,} ({dh_params/1e6:.3f}M)") - - # ---- 数据集 ---- - dataset_train = GinkaPretrainDataset(args.train) - dataset_val = GinkaPretrainDataset(args.validate) - dataloader_train = DataLoader( - dataset_train, batch_size=BATCH_SIZE, shuffle=True, - num_workers=0, pin_memory=(device.type == "cuda"), - ) - dataloader_val = DataLoader( - dataset_val, batch_size=BATCH_SIZE, shuffle=False, - num_workers=0, - ) - print(f"训练集: {len(dataset_train)} 条 验证集: {len(dataset_val)} 条") - - # ---- 优化器 ---- - all_params = list(model_vq.parameters()) + list(decode_head.parameters()) - optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2) - scheduler = optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=args.epochs, eta_min=1e-6 - ) - - # ---- 续训 ---- - start_epoch = 0 - if args.resume: - ckpt = torch.load(args.state, map_location=device) - model_vq.load_state_dict(ckpt["vq_state"], strict=False) - if "dh_state" in ckpt: - decode_head.load_state_dict(ckpt["dh_state"], strict=False) - if args.load_optim and ckpt.get("optim_state") is not None: - optimizer.load_state_dict(ckpt["optim_state"]) - start_epoch = ckpt.get("epoch", 0) - print(f"从 epoch {start_epoch} 接续训练。") - - # ---- 训练循环 ---- - for epoch in tqdm(range(start_epoch, start_epoch + args.epochs), - desc="VQ Pretrain", disable=disable_tqdm): - model_vq.train() - decode_head.train() - - loss_total = 0.0 - ce_total = 0.0 - commit_total = 0.0 - entropy_total = 0.0 - - for raw_map in tqdm(dataloader_train, leave=False, - desc="Epoch Progress", disable=disable_tqdm): - raw_map = raw_map.to(device) # [B, H*W] - - # 1. 编码 - z_q, _, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) - - # 2. 解码→全图重建(focal loss 缓解墙壁/空地主导问题) - logits = decode_head(z_q) # [B, H*W, C] - ce_loss = focal_loss(logits.permute(0, 2, 1), raw_map) - - # 3. 总损失(重建 + VQ 正则) - loss = ce_loss + vq_loss - - optimizer.zero_grad() - loss.backward() - torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0) - optimizer.step() - - loss_total += loss.detach().item() - ce_total += ce_loss.detach().item() - commit_total += commit_loss.detach().item() - entropy_total += entropy_loss.detach().item() - - scheduler.step() - - n = len(dataloader_train) - tqdm.write( - f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"Epoch {epoch + 1:4d} | " - f"Loss {loss_total/n:.5f} " - f"Focal {ce_total/n:.5f} " - f"Commit {commit_total/n:.5f} " - f"Entropy {entropy_total/n:.5f} | " - f"LR {scheduler.get_last_lr()[0]:.6f}" - ) - - # ---- 检查点 + 验证 ---- - if (epoch + 1) % args.checkpoint == 0: - ckpt_path = f"result/pretrain/pretrain-{epoch + 1}.pth" - torch.save({ - "epoch": epoch + 1, - "vq_state": model_vq.state_dict(), - "dh_state": decode_head.state_dict(), - "optim_state": optimizer.state_dict(), - }, ckpt_path) - tqdm.write(f" 检查点已保存: {ckpt_path}") - - metrics = validate(model_vq, decode_head, dataloader_val) - acc_str = f" [Validate] Acc {metrics['acc']:.4f} Wall Recall {metrics['wall_recall']:.4f}" - - # 输出有样本的类别准确率 - pc = metrics["per_class"] - detail = " ".join( - f"c{c}={v:.3f}" for c, v in sorted(pc.items()) if v < 1.0 - ) - if detail: - acc_str += f"\n Per-class: {detail}" - tqdm.write(acc_str) - - model_vq.train() - decode_head.train() - - # ---- 保存最终 VQ 编码器权重 ---- - final_path = "result/pretrain/pretrain_final.pth" - torch.save({ - "epoch": start_epoch + args.epochs, - "vq_state": model_vq.state_dict(), - # 不保存解码头:联合训练阶段不需要 - }, final_path) - print(f"\n预训练完成。编码器权重已保存至: {final_path}") - print(f"联合训练阶段 1 启动命令(编码器冻结热身):") - print(f" python -m ginka.train_vq --resume True --state {final_path} --freeze_vq True") - - -# --------------------------------------------------------------------------- -if __name__ == "__main__": - torch.set_num_threads(4) - train() diff --git a/ginka/train_pretrain_split.py b/ginka/train_pretrain_split.py index 79fd350..ca6033a 100644 --- a/ginka/train_pretrain_split.py +++ b/ginka/train_pretrain_split.py @@ -35,7 +35,7 @@ from .utils import masked_focal # 超参数 # --------------------------------------------------------------------------- BATCH_SIZE = 64 -NUM_CLASSES = 16 +NUM_CLASSES = 7 MAP_SIZE = 13 * 13 FOCAL_GAMMA = 2.0 @@ -46,14 +46,14 @@ CH1_D_MODEL = 64 CH1_NHEAD = 8 # 通道 2:关卡门控 -CH2_KEEP = {0, 1, 2, 9, 10} -CH2_LOSS = {0, 1, 2, 9, 10} +CH2_KEEP = {0, 1, 2, 4, 5} +CH2_LOSS = {0, 1, 2, 4, 5} CH2_D_MODEL = 64 CH2_NHEAD = 8 # 通道 3:收集资源 CH3_KEEP = None # 完整地图,无需切片 -CH3_LOSS = {0, 1, 2, 3, 9, 10} +CH3_LOSS = {0, 1, 2, 3, 4, 5} CH3_D_MODEL = 64 CH3_NHEAD = 8 @@ -125,9 +125,9 @@ def validate( # 每类 tile 的 tp / gt 计数 ch1_tp, ch1_gt = 0, 0 # wall(1) - ch2_tp = {t: 0 for t in CH2_LOSS} # {2,9,10} + ch2_tp = {t: 0 for t in CH2_LOSS} # {2,4,5} ch2_gt = {t: 0 for t in CH2_LOSS} - ch3_tp = {t: 0 for t in CH3_LOSS} # {3,4,5,6,7,8} + ch3_tp = {t: 0 for t in CH3_LOSS} # {3,4,5} ch3_gt = {t: 0 for t in CH3_LOSS} # codebook 使用频次(用于熵估算) diff --git a/ginka/train_vq.py b/ginka/train_vq.py index e6c0e0f..f7d3f70 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -37,8 +37,8 @@ from shared.image import matrix_to_image_cv # 超参数 # --------------------------------------------------------------------------- BATCH_SIZE = 64 -NUM_CLASSES = 16 -MASK_TOKEN = 15 +NUM_CLASSES = 7 +MASK_TOKEN = 6 GENERATE_STEP = 18 # 推理时 MaskGIT 迭代步数 MAP_SIZE = 13 * 13 MAP_H = MAP_W = 13 @@ -61,7 +61,7 @@ VQ_DIM_FF = 512 # 通道专属损失计算范围(用于监控验证召回率) CH1_LOSS = {1} -CH2_LOSS = {2, 9, 10} +CH2_LOSS = {2, 4, 5} CH3_LOSS = {3} # 资源已压缩为单一 tile=3 # MaskGIT 超参 diff --git a/tiles/10.png b/tiles/10.png deleted file mode 100644 index d2eb533..0000000 Binary files a/tiles/10.png and /dev/null differ diff --git a/tiles/15.png b/tiles/15.png deleted file mode 100644 index eb62785..0000000 Binary files a/tiles/15.png and /dev/null differ diff --git a/tiles/3.png b/tiles/3.png index 339c1c3..08409ab 100644 Binary files a/tiles/3.png and b/tiles/3.png differ diff --git a/tiles/4.png b/tiles/4.png index 08409ab..1329097 100644 Binary files a/tiles/4.png and b/tiles/4.png differ diff --git a/tiles/5.png b/tiles/5.png index 792ed88..d2eb533 100644 Binary files a/tiles/5.png and b/tiles/5.png differ diff --git a/tiles/6.png b/tiles/6.png index 4b8d3a6..eb62785 100644 Binary files a/tiles/6.png and b/tiles/6.png differ diff --git a/tiles/7.png b/tiles/7.png deleted file mode 100644 index b121323..0000000 Binary files a/tiles/7.png and /dev/null differ diff --git a/tiles/8.png b/tiles/8.png deleted file mode 100644 index 38d7a35..0000000 Binary files a/tiles/8.png and /dev/null differ diff --git a/tiles/9.png b/tiles/9.png deleted file mode 100644 index 1329097..0000000 Binary files a/tiles/9.png and /dev/null differ