From 79cf3ab2262b3151b9de8a1f56ff760b6888fb5a Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 19 Jan 2026 22:30:00 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20vae=20=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 225 +++++++++++++++++++++++++++++++++++++ ginka/utils.py | 7 ++ ginka/vae_rnn/decoder.py | 236 +++++++++++++++++++++++++++++++++++++++ ginka/vae_rnn/encoder.py | 55 +++++++++ ginka/vae_rnn/loss.py | 17 +++ ginka/vae_rnn/vae.py | 23 ++++ 6 files changed, 563 insertions(+) create mode 100644 ginka/train_vae.py create mode 100644 ginka/utils.py create mode 100644 ginka/vae_rnn/decoder.py create mode 100644 ginka/vae_rnn/encoder.py create mode 100644 ginka/vae_rnn/loss.py create mode 100644 ginka/vae_rnn/vae.py diff --git a/ginka/train_vae.py b/ginka/train_vae.py new file mode 100644 index 0000000..cec3143 --- /dev/null +++ b/ginka/train_vae.py @@ -0,0 +1,225 @@ +import argparse +import os +import sys +from datetime import datetime +import torch +import torch.nn.functional as F +import torch.optim as optim +import cv2 +import numpy as np +from torch_geometric.loader import DataLoader +from tqdm import tqdm +from .vae_rnn.vae import GinkaVAE +from .vae_rnn.loss import VAELoss +from .dataset import GinkaRNNDataset +from shared.image import matrix_to_image_cv + +# 手工标注标签定义(暂时不用): +# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, +# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风 +# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门 +# 24. 中心对称, 25. 部分对称, 26. 鱼骨 + +# 自动标注标签定义(暂时不用): +# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊 +# 32. 平面塔, 33. 转换塔, 34. 道具塔 + +# 标量值定义: +# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块 +# 1. 墙体密度,墙壁/地图面积 +# 2. 装饰密度,装饰数量/地图面积 +# 3. 门密度,门数量/地图面积 +# 4. 怪物密度,怪物数量/地图面积 +# 5. 资源密度,资源数量/地图面积 +# 6. 宝石密度,宝石数量/地图面积 +# 7. 血瓶密度,血瓶数量/地图面积 +# 8. 钥匙密度,钥匙数量/地图面积 +# 9. 道具密度,道具数量/地图面积 +# 10. 入口数量 +# 11. 机关门数量 +# 12. 咸鱼门数量(多层咸鱼门只算一个) + +# 图块定义: +# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地), +# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门 +# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启 +# 10-12. 三种等级的红宝石 +# 13-15. 三种等级的蓝宝石 +# 16-18. 三种等级的绿宝石 +# 19-22. 四种等级的血瓶 +# 23-25. 三种等级的道具 +# 26-28. 三种等级的怪物 +# 29. 入口,不区分楼梯和箭头 + +BATCH_SIZE = 32 + +device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") +os.makedirs("result", exist_ok=True) +os.makedirs("result/vae", exist_ok=True) +os.makedirs("result/ginka_vae_img", exist_ok=True) + +disable_tqdm = not sys.stdout.isatty() + +def gt_prob(epoch: int, max_epoch: int) -> float: + progress = epoch / max_epoch + return max(1.2 * progress - 0.2, 0) + +def parse_arguments(): + parser = argparse.ArgumentParser(description="training codes") + parser.add_argument("--resume", type=bool, default=False) + parser.add_argument("--state_ginka", type=str, default="result/vae/ginka-100.pth") + parser.add_argument("--train", type=str, default="ginka-dataset.json") + parser.add_argument("--validate", type=str, default="ginka-eval.json") + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--checkpoint", type=int, default=5) + parser.add_argument("--load_optim", type=bool, default=True) + args = parser.parse_args() + return args + +def train(): + print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") + + args = parse_arguments() + + vae = GinkaVAE(device).to(device) + + dataset = GinkaRNNDataset(args.train, device) + dataset_val = GinkaRNNDataset(args.validate, device) + dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) + dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) + + optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4) + scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6) + + criterion = VAELoss() + + # 用于生成图片 + tile_dict = dict() + for file in os.listdir('tiles'): + name = os.path.splitext(file)[0] + tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) + + if args.resume: + data_ginka = torch.load(args.state_ginka, map_location=device) + + vae.load_state_dict(data_ginka["model_state"], strict=False) + + if args.load_optim: + if data_ginka.get("optim_state") is not None: + optimizer_ginka.load_state_dict(data_ginka["optim_state"]) + + print("Train from loaded state.") + + for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm): + loss_total = torch.Tensor([0]).to(device) + reco_loss_total = torch.Tensor([0]).to(device) + kl_loss_total = torch.Tensor([0]).to(device) + + for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): + target_map = batch["target_map"].to(device) + + fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs)) + + loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05) + + loss.backward() + optimizer_ginka.step() + loss_total += loss.detach() + reco_loss_total += reco_loss.detach() + kl_loss_total += kl_loss.detach() + + avg_loss = loss_total.item() / len(dataloader) + avg_reco_loss = reco_loss_total.item() / len(dataloader) + avg_kl_loss = kl_loss_total.item() / len(dataloader) + tqdm.write( + f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + + f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco Loss: {avg_reco_loss:.6f} | " + + f"KL Loss: {avg_kl_loss:.6f} | LR: {optimizer_ginka.param_groups[0]['lr']:.6f}" + ) + + scheduler_ginka.step() + + # 每若干轮输出一次图片,并保存检查点 + if (epoch + 1) % args.checkpoint == 0: + # 保存检查点 + torch.save({ + "model_state": vae.state_dict(), + "optim_state": optimizer_ginka.state_dict(), + }, f"result/rnn/ginka-{epoch + 1}.pth") + + val_loss_total = torch.Tensor([0]).to(device) + reco_loss_total = torch.Tensor([0]).to(device) + kl_loss_total = torch.Tensor([0]).to(device) + with torch.no_grad(): + idx = 0 + gap = 5 + color = (255, 255, 255) # 白色 + vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 + # 地图重建展示 + for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): + target_map = batch["target_map"].to(device) + + fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs)) + + loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05) + val_loss_total += loss.detach() + reco_loss_total += reco_loss.detach() + kl_loss_total += kl_loss.detach() + + fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy() + fake_img = matrix_to_image_cv(fake_map[0], tile_dict) + real_map = target_map.cpu().numpy() + real_img = matrix_to_image_cv(real_map[0], tile_dict) + img = np.block([[real_img], [vline], [fake_img]]) + cv2.imwrite(f"result/ginka_vae_img/{idx}.png", img) + + idx += 1 + + # 随机采样 + for i in range(0, 8): + z = torch.randn(1, 32).to(device) + + fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1) + fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy() + fake_img = matrix_to_image_cv(fake_map[0], tile_dict) + + cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img) + + # 插值 + map1 = torch.LongTensor(dataset_val.data[0]["map"]).to(device).reshape(1, 13, 13) + map2 = torch.LongTensor(dataset_val.data[1]["map"]).to(device).reshape(1, 13, 13) + map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device) + map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device) + mu1, logvar1 = vae.encoder(map1_onehot) + mu2, logvar2 = vae.encoder(map2_onehot) + z1 = vae.reparameterize(mu1, logvar1) + z2 = vae.reparameterize(mu2, logvar2) + real_img1 = matrix_to_image_cv(map1[0], tile_dict) + real_img2 = matrix_to_image_cv(map2[0], tile_dict) + for t in torch.linspace(0, 1, 8): + z = z1 * (1 - t / 8) + z2 * t / 8 + fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1) + fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy() + fake_img = matrix_to_image_cv(fake_map[0], tile_dict) + img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]]) + + cv2.imwrite(f"result/ginka_vae_img/{t}_linspace.png", img) + + avg_loss_val = val_loss_total.item() / len(dataloader_val) + avg_reco_loss = reco_loss_total.item() / len(dataloader_val) + avg_kl_loss = kl_loss_total.item() / len(dataloader_val) + tqdm.write( + f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " + + f"Loss: {avg_loss_val:.6f} | Reco Loss: {avg_reco_loss:.6f} | " + + f"KL Loss: {avg_kl_loss:.6f}" + ) + + print("Train ended.") + torch.save({ + "model_state": vae.state_dict(), + }, f"result/ginka_rnn.pth") + + +if __name__ == "__main__": + torch.set_num_threads(4) + train() diff --git a/ginka/utils.py b/ginka/utils.py new file mode 100644 index 0000000..f14886e --- /dev/null +++ b/ginka/utils.py @@ -0,0 +1,7 @@ +import torch + +def print_memory(device, tag=""): + if torch.cuda.is_available(): + print(f"{tag} | 当前显存: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated(device) / 1024**2:.2f} MB") + else: + print("当前设备不支持 cuda.") \ No newline at end of file diff --git a/ginka/vae_rnn/decoder.py b/ginka/vae_rnn/decoder.py new file mode 100644 index 0000000..a649829 --- /dev/null +++ b/ginka/vae_rnn/decoder.py @@ -0,0 +1,236 @@ +import time +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..utils import print_memory + +class GinkaMapPatch(nn.Module): + def __init__(self, tile_classes=32, width=13, height=13): + super().__init__() + + # 地图局部卷积,用于捕获局部结构信息 + + self.width = width + self.height = height + self.tile_classes = 32 + + self.patch_cnn = nn.Sequential( + nn.Conv2d(tile_classes + 1, 64, 3, padding=1), + nn.Dropout(0.2), + nn.BatchNorm2d(64), + nn.GELU(), + + nn.Conv2d(64, 128, 3), + nn.Dropout(0.2), + nn.BatchNorm2d(128), + nn.GELU(), + + nn.Flatten() + ) + self.fc = nn.Linear(128 * 3 * 3, 256) + + def forward(self, map: torch.Tensor, x: int, y: int): + """ + map: [B, H, W] + """ + B, H, W = map.shape + mask = torch.zeros([B, 5, 5]).to(map.device) + result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device) + left = x - 2 if x >= 2 else 0 + right = x + 3 if x < self.width - 2 else self.width + top = y - 4 if y >= 4 else 0 + bottom = y + 1 + + res_left = left - (x - 2) + res_right = right - (x + 3) + 5 + res_top = top - (y - 4) + res_bottom = 5 + + result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right] + # 没画到的地方要置为 0 + result[:, 4, 2] = 0 + result[:, 4, 3] = 0 + result[:, 4, 4] = 0 + mask[:, res_top:res_bottom, res_left:res_right] = 1 + mask[:, 4, 2] = 0 + mask[:, 4, 3] = 0 + mask[:, 4, 4] = 0 + masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5]).to(map.device) + masked_result[:, 0:32] = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float() + masked_result[:, 32] = mask + + feat = self.patch_cnn(masked_result) + feat = self.fc(feat) + return feat + +class GinkaTileEmbedding(nn.Module): + def __init__(self, tile_classes=32, embed_dim=256): + super().__init__() + + # 图块编码,上一次画的图块 + + self.embedding = nn.Embedding(tile_classes, embed_dim) + + def forward(self, tile: torch.Tensor): + return self.embedding(tile) + +class GinkaPosEmbedding(nn.Module): + def __init__(self, width=13, height=13, embed_dim=256): + super().__init__() + + # 位置编码 + + self.width = width + self.height = height + + self.row_embedding = nn.Embedding(width, embed_dim) + self.col_embedding = nn.Embedding(height, embed_dim) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + row = self.row_embedding(x).squeeze(1) + col = self.col_embedding(y).squeeze(1) + + return row, col + +class GinkaInputFusion(nn.Module): + def __init__(self, d_model=256): + super().__init__() + + # 使用 Transformer 进行信息整合 + + self.transformer = nn.TransformerEncoder( + nn.TransformerEncoderLayer( + d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True, + dropout=0.2 + ), + num_layers=4 + ) + + def forward( + self, tile_embed: torch.Tensor, cond_vec: torch.Tensor, + row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor + ): + """ + tile_embed: [B, 256] + cond_vec: [B, 256] + row_embed: [B, 256] + col_embed: [B, 256] + patch_vec: [B, 256] + """ + vec = torch.stack([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1) + feat = self.transformer(vec) + return feat[:, 0] + +class GinkaRNN(nn.Module): + def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512): + super().__init__() + + # GRU + self.gru = nn.GRUCell(input_dim, hidden_dim) + self.drop = nn.Dropout(0.2) + self.fc = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.GELU(), + + nn.Linear(hidden_dim, tile_classes) + ) + + def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor): + """ + feat_fusion: [B, input_dim] + hidden: [B, hidden_dim] + """ + hidden = self.drop(self.gru(feat_fusion, hidden)) + logits = self.fc(hidden) + return logits, hidden + +class VAEDecoder(nn.Module): + def __init__(self, device: torch.device, start_tile=31, map_vec_dim=32, width=13, height=13): + super().__init__() + + self.device = device + self.width = width + self.height = height + self.start_tile = start_tile + + self.rnn_hidden = 512 + self.tile_classes = 32 + + # 模型结构 + self.map_vec_fc = nn.Sequential( + nn.Linear(map_vec_dim, 256) + ) + self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes) + self.pos_embedding = GinkaPosEmbedding() + self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes) + self.feat_fusion = GinkaInputFusion() + self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden) + + def forward(self, map_vec: torch.Tensor, target_map: torch.Tensor, use_self_probility=0): + """ + map_vec: [B, vec_dim] + target_map: [B, H, W] + use_self: 是否使用自己生成的上一步结果执行下一步 + """ + B, C = map_vec.shape + + # 张量声明 + now_tile = torch.LongTensor([self.start_tile]).to(self.device).expand(B) + + map = torch.zeros([B, self.height, self.width], dtype=torch.int32).to(self.device) + output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device) + hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device) + + map_vec = self.map_vec_fc(map_vec) + + for y in range(0, self.height): + for x in range(0, self.width): + x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1) + y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1) + # 位置编码、图块编码、地图局部编码 + tile_embed = self.tile_embedding(now_tile) + row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor) + use_self = random.random() < use_self_probility + map_patch = self.map_patch(map if use_self else target_map, x, y) + # 编码特征融合 + feat = self.feat_fusion(tile_embed, map_vec, row_embed, col_embed, map_patch) + # RNN 输出 + logits, h = self.rnn(feat, hidden) + # 处理输出 + output_logits[:, y, x] = logits[:] + hidden = h + tile_id = torch.argmax(logits, dim=1).detach() + map[:, y, x] = tile_id[:] + now_tile = tile_id if use_self else target_map[:, y, x].detach() + + return output_logits.permute(0, 3, 1, 2) + +if __name__ == "__main__": + device = torch.device("cpu") + + input = torch.randint(0, 32, [1, 13, 13]).to(device) + map_vec = torch.rand(1, 32).to(device) + + # 初始化模型 + model = VAEDecoder("cpu").to(device) + + print_memory("初始化后") + + # 前向传播 + start = time.perf_counter() + fake_logits, fake_map = model(map_vec, input, 0) + end = time.perf_counter() + + print_memory("前向传播后") + + print(f"推理耗时: {end - start}") + print(f"输出形状: fake_logits={fake_logits.shape}, fake_map={fake_map.shape}") + print(f"Map Vector FC parameters: {sum(p.numel() for p in model.map_vec_fc.parameters())}") + print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}") + print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}") + print(f"Map Patch parameters: {sum(p.numel() for p in model.map_patch.parameters())}") + print(f"Feature Fusion parameters: {sum(p.numel() for p in model.feat_fusion.parameters())}") + print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/vae_rnn/encoder.py b/ginka/vae_rnn/encoder.py new file mode 100644 index 0000000..419b1ba --- /dev/null +++ b/ginka/vae_rnn/encoder.py @@ -0,0 +1,55 @@ +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..utils import print_memory + +class VAEEncoder(nn.Module): + def __init__(self, tile_classes=32, latent_dim=32): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(tile_classes, 64, 3, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(), + + nn.Conv2d(64, 128, 3, stride=2, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(), + + nn.Conv2d(128, 256, 3, stride=2, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(), + + nn.Flatten() + ) + self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim) + self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim) + + def forward(self, x): + h = self.conv(x) + mu = self.fc_mu(h) + logvar = self.fc_logvar(h) + return mu, logvar + +if __name__ == "__main__": + device = torch.device("cpu") + + input = torch.randint(0, 32, [1, 13, 13]).to(device) + input = F.one_hot(input, 32).permute(0, 3, 1, 2).float() + + # 初始化模型 + model = VAEEncoder().to(device) + + print_memory("初始化后") + + # 前向传播 + start = time.perf_counter() + mu, logvar = model(input) + end = time.perf_counter() + + print_memory("前向传播后") + + print(f"推理耗时: {end - start}") + print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}") + print(f"CNN parameters: {sum(p.numel() for p in model.conv.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/vae_rnn/loss.py b/ginka/vae_rnn/loss.py new file mode 100644 index 0000000..9595448 --- /dev/null +++ b/ginka/vae_rnn/loss.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + +class VAELoss: + def __init__(self): + self.num_classes = 32 + + def vae_loss(self, logits, target, mu, logvar, beta=0.1): + # target: [B, 13, 13] + target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2) + recon_loss = F.cross_entropy(logits, target) + + kl_loss = -0.5 * torch.mean( + 1 + logvar - mu.pow(2) - logvar.exp() + ) + + return recon_loss + beta * kl_loss, recon_loss, kl_loss diff --git a/ginka/vae_rnn/vae.py b/ginka/vae_rnn/vae.py new file mode 100644 index 0000000..95da9ad --- /dev/null +++ b/ginka/vae_rnn/vae.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .encoder import VAEEncoder +from .decoder import VAEDecoder + +class GinkaVAE(nn.Module): + def __init__(self, device, tile_classes=32, latent_dim=32): + super().__init__() + self.encoder = VAEEncoder(tile_classes, latent_dim) + self.decoder = VAEDecoder(device) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, target_map: torch.Tensor, use_self_probility=0): + target = F.one_hot(target_map, num_classes=32).float().permute(0, 3, 1, 2) + mu, logvar = self.encoder(target) + z = self.reparameterize(mu, logvar) + logits = self.decoder(z, target_map, use_self_probility) + return logits, mu, logvar \ No newline at end of file