diff --git a/data/src/auto.ts b/data/src/auto.ts index 359acfa..61b5f08 100644 --- a/data/src/auto.ts +++ b/data/src/auto.ts @@ -253,8 +253,8 @@ const labelConfig: IAutoLabelConfig = { entry: 29 }, allowedSize: [[13, 13]], - allowUselessBranch: false, - maxWallDensityStd: 0.25, + allowUselessBranch: true, + maxWallDensityStd: 0.5, minEnemyRatio: 0.02, maxEnemyRatio: 0.3, minWallRatio: 0.1, @@ -269,9 +269,9 @@ const labelConfig: IAutoLabelConfig = { maxEntryCount: 4, ignoreIssues: true, customTowerFilter: info => { - // if (info.name !== 'Apeiria') { - // return false; - // } + if (info.name !== 'Apeiria') { + return false; + } // if (info.color !== TowerColor.Blue && info.color !== TowerColor.Green) { // return false; // } @@ -305,9 +305,9 @@ const labelConfig: IAutoLabelConfig = { if (ignoredFloor[floor.tower.name]?.includes(floor.mapId)) { return false; } - if (floor.tower.name === 'Apeiria') { - return Math.random() < 0.2; - } + // if (floor.tower.name === 'Apeiria') { + // return Math.random() < 0.2; + // } return true; } }; diff --git a/ginka/train_transformer_vae.py b/ginka/train_transformer_vae.py new file mode 100644 index 0000000..59258bd --- /dev/null +++ b/ginka/train_transformer_vae.py @@ -0,0 +1,277 @@ +import argparse +import os +import sys +import random +from datetime import datetime +import torch +import torch.nn.functional as F +import torch.optim as optim +import cv2 +import numpy as np +from torch_geometric.loader import DataLoader +from tqdm import tqdm +from .transformer_vae.vae import GinkaTransformerVAE +from .vae_rnn.loss import VAELoss +from .vae_rnn.scheduler import VAEScheduler +from .dataset import GinkaRNNDataset +from shared.image import matrix_to_image_cv + +# 手工标注标签定义(暂时不用): +# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, +# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风 +# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门 +# 24. 中心对称, 25. 部分对称, 26. 鱼骨 + +# 自动标注标签定义(暂时不用): +# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊 +# 32. 平面塔, 33. 转换塔, 34. 道具塔 + +# 标量值定义: +# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块 +# 1. 墙体密度,墙壁/地图面积 +# 2. 装饰密度,装饰数量/地图面积 +# 3. 门密度,门数量/地图面积 +# 4. 怪物密度,怪物数量/地图面积 +# 5. 资源密度,资源数量/地图面积 +# 6. 宝石密度,宝石数量/地图面积 +# 7. 血瓶密度,血瓶数量/地图面积 +# 8. 钥匙密度,钥匙数量/地图面积 +# 9. 道具密度,道具数量/地图面积 +# 10. 入口数量 +# 11. 机关门数量 +# 12. 咸鱼门数量(多层咸鱼门只算一个) + +# 图块定义: +# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地), +# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门 +# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启 +# 10-12. 三种等级的红宝石 +# 13-15. 三种等级的蓝宝石 +# 16-18. 三种等级的绿宝石 +# 19-22. 四种等级的血瓶 +# 23-25. 三种等级的道具 +# 26-28. 三种等级的怪物 +# 29. 入口,不区分楼梯和箭头 + +BATCH_SIZE = 8 +LATENT_DIM = 48 +KL_BETA = 0.1 +SELF_GATE = 0.5 +GATE_EPOCH = 5 +VAL_BATCH_DIVIDER = 8 +PROB_STEP = 0.05 + +device = torch.device( + "cuda:1" if torch.cuda.is_available() + else "mps" if torch.mps.is_available() + else "cpu" +) +os.makedirs("result", exist_ok=True) +os.makedirs("result/vae", exist_ok=True) +os.makedirs("result/ginka_vae_img", exist_ok=True) + +disable_tqdm = not sys.stdout.isatty() + +def parse_arguments(): + parser = argparse.ArgumentParser(description="training codes") + parser.add_argument("--resume", type=bool, default=False) + parser.add_argument("--state_ginka", type=str, default="result/vae/ginka-100.pth") + parser.add_argument("--train", type=str, default="ginka-dataset.json") + parser.add_argument("--validate", type=str, default="ginka-eval.json") + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--checkpoint", type=int, default=5) + parser.add_argument("--load_optim", type=bool, default=True) + args = parser.parse_args() + return args + +def train(): + print(f"Using {device.type} to train model.") + + args = parse_arguments() + + vae = GinkaTransformerVAE(latent_dim=LATENT_DIM).to(device) + + dataset = GinkaRNNDataset(args.train, device) + dataset_val = GinkaRNNDataset(args.validate, device) + dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) + dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) + + optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4) + # 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习 + scheduler_ginka = VAEScheduler( + optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=1e-4, min_lr=1e-6 + ) + + criterion = VAELoss() + + self_prob = 0 + prob_epochs = 0 + + # 用于生成图片 + tile_dict = dict() + for file in os.listdir('tiles'): + name = os.path.splitext(file)[0] + tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) + + if args.resume: + data_ginka = torch.load(args.state_ginka, map_location=device) + + vae.load_state_dict(data_ginka["model_state"], strict=False) + + if args.load_optim: + if data_ginka.get("optim_state") is not None: + optimizer_ginka.load_state_dict(data_ginka["optim_state"]) + + print("Train from loaded state.") + + for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm): + loss_total = torch.Tensor([0]).to(device) + reco_loss_total = torch.Tensor([0]).to(device) + kl_loss_total = torch.Tensor([0]).to(device) + + vae.teacher_forcing() + for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): + target_map = batch["target_map"].to(device) + B, H, W = target_map.shape + input = target_map.view(B, H * W) + + optimizer_ginka.zero_grad() + fake_logits, mu, logvar = vae(input, self_prob) + + loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, input, mu, logvar, KL_BETA) + + loss.backward() + torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0) + optimizer_ginka.step() + loss_total += loss.detach() + reco_loss_total += reco_loss.detach() + kl_loss_total += kl_loss.detach() + + avg_loss = loss_total.item() / len(dataloader) + avg_reco_loss = reco_loss_total.item() / len(dataloader) + avg_kl_loss = kl_loss_total.item() / len(dataloader) + tqdm.write( + f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + + f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " + + f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}" + ) + + # 验证集 + # with torch.no_grad(): + # val_loss_total = torch.Tensor([0]).to(device) + # for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): + # target_map = batch["target_map"].to(device) + + # fake_logits, mu, logvar = vae(target_map, 1 - gt_prob) + + # loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) + # val_loss_total += loss.detach() + + # avg_loss_val = val_loss_total.item() / len(dataloader_val) + + # 先使用训练集的损失值,因为过拟合比较严重,后续再想办法 + if avg_loss < SELF_GATE: + prob_epochs += 1 + else: + prob_epochs = 0 + + if prob_epochs >= GATE_EPOCH and self_prob < 1: + self_prob += PROB_STEP + prob_epochs = 0 + if self_prob > 1: + self_prob = 1 + + self_prob = 1 + + scheduler_ginka.step(avg_loss, self_prob) + + # 每若干轮输出一次图片,并保存检查点 + if (epoch + 1) % 1 == 0: + # 保存检查点 + torch.save({ + "model_state": vae.state_dict(), + "optim_state": optimizer_ginka.state_dict(), + }, f"result/rnn/ginka-{epoch + 1}.pth") + + val_loss_total = torch.Tensor([0]).to(device) + val_reco_loss_total = torch.Tensor([0]).to(device) + val_kl_loss_total = torch.Tensor([0]).to(device) + vae.eval() + with torch.no_grad(): + idx = 0 + gap = 5 + color = (255, 255, 255) # 白色 + vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 + # 地图重建展示 + for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): + target_map = batch["target_map"].to(device) + B, H, W = target_map.shape + input = target_map.view(B, H * W) + + fake_logits, mu, logvar = vae(input, self_prob) + + loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, input, mu, logvar, KL_BETA) + val_loss_total += loss.detach() + val_reco_loss_total += reco_loss.detach() + val_kl_loss_total += kl_loss.detach() + + fake_map = torch.argmax(fake_logits, dim=2).view(B, H, W).cpu().numpy() + fake_img = matrix_to_image_cv(fake_map[0], tile_dict) + real_map = target_map.cpu().numpy() + real_img = matrix_to_image_cv(real_map[0], tile_dict) + img = np.block([[real_img], [vline], [fake_img]]) + cv2.imwrite(f"result/ginka_vae_img/{idx}.png", img) + + idx += 1 + + # 随机采样 + for i in range(0, 8): + z = torch.randn(1, LATENT_DIM).to(device) + + vae.autoregressive() + fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device)) + fake_map = fake_logits.view(-1, 13, 13).cpu().numpy() + fake_img = matrix_to_image_cv(fake_map[0], tile_dict) + + cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img) + + # 插值 + val_length = len(dataset_val.data) + index1 = random.randint(0, val_length - 1) + index2 = random.randint(0, val_length - 1) + map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).view(1, 169) + map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).view(1, 169) + mu1, logvar1 = vae.encoder(map1) + mu2, logvar2 = vae.encoder(map2) + z1 = vae.reparameterize(mu1, logvar1) + z2 = vae.reparameterize(mu2, logvar2) + real_img1 = matrix_to_image_cv(map1[0].view(13, 13).cpu().numpy(), tile_dict) + real_img2 = matrix_to_image_cv(map2[0].view(13, 13).cpu().numpy(), tile_dict) + i = 0 + for t in torch.linspace(0, 1, 8): + z = z1 * (1 - t / 8) + z2 * t / 8 + fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device)) + fake_map = fake_logits.view(-1, 13, 13).cpu().numpy() + fake_img = matrix_to_image_cv(fake_map[0], tile_dict) + img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]]) + + cv2.imwrite(f"result/ginka_vae_img/{i}_linspace.png", img) + i += 1 + + avg_loss_val = val_loss_total.item() / len(dataloader_val) + avg_reco_loss_val = val_reco_loss_total.item() / len(dataloader_val) + avg_kl_loss_val = val_kl_loss_total.item() / len(dataloader_val) + tqdm.write( + f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " + + f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}" + ) + + print("Train ended.") + torch.save({ + "model_state": vae.state_dict(), + }, f"result/ginka_transformer.pth") + + +if __name__ == "__main__": + torch.set_num_threads(4) + train() diff --git a/ginka/transformer_vae/decoder.py b/ginka/transformer_vae/decoder.py new file mode 100644 index 0000000..b02b8de --- /dev/null +++ b/ginka/transformer_vae/decoder.py @@ -0,0 +1,105 @@ +import time +import torch +import torch.nn as nn +from ..utils import print_memory + +class GinkaTransformerDecoder(nn.Module): + def __init__(self, num_classes=32, dim_ff=256, nhead=4, num_layers=4, map_size=13*13): + super().__init__() + self.autoregressive = False + self.dim_ff = dim_ff + self.map_size = map_size + self.embedding = nn.Embedding(num_classes, dim_ff) + self.pos_embedding = nn.Embedding(map_size, dim_ff) + self.encoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True), + num_layers=max(num_layers // 2, 1) + ) + self.decoder = nn.TransformerDecoder( + nn.TransformerDecoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True), + num_layers=num_layers + ) + self.fc = nn.Sequential( + nn.Linear(dim_ff, num_classes) + ) + + def forward(self, z: torch.Tensor, target_map: torch.Tensor): + # z: [B, dim_ff] + # target_map: [B, H * W] + # training output: [B, H * W, dim_ff] + # evaling output: [B, H * W] + B, L = target_map.shape + + memory = self.encoder(z.unsqueeze(1)) # [B, 1, dim_ff] + mask = torch.triu(torch.ones(L, L, dtype=torch.bool)).to(z.device) # [B, H * W, H * W] + + # when training, use teacher forcing + if not self.autoregressive: + map = self.embedding(target_map) + pos_embed = self.pos_embedding(torch.arange(L, dtype=torch.long).to(z.device)) + map = map + pos_embed # [B, H * W, dim_ff] + decoded = self.decoder(map, memory, tgt_mask=mask) # [B, H * W, dim_ff] + output = self.fc(decoded) + return output + + # when evaling, use autoregressive generation + else: + output = torch.zeros([B, L], dtype=torch.int).to(z.device) + for idx in range(0, self.map_size): + embed = self.embedding(output) + pos_embed = self.pos_embedding(torch.IntTensor([idx]).repeat(B, 1).to(z.device)) + map = embed + pos_embed # [B, H * W, dim_ff] + decoded = self.decoder(map, memory, tgt_mask=mask) + decoded = self.fc(decoded) # [B, H * W, dim_ff] + output[:, idx] = torch.argmax(decoded[:, idx, :], dim=1) + + return output + +class GinkaTransformerVAEDecoder(nn.Module): + def __init__( + self, latent_dim=32, num_classes=32, dim_ff=256, nhead=4, num_layers=4, + map_size=13*13 + ): + super().__init__() + self.map_size = map_size + self.input = nn.Sequential( + nn.Linear(latent_dim, dim_ff), + nn.Dropout(0.3), + nn.LayerNorm(dim_ff), + nn.ReLU(), + + nn.Linear(dim_ff, dim_ff) + ) + self.decoder = GinkaTransformerDecoder( + num_classes=num_classes, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers, map_size=map_size + ) + + def forward(self, z: torch.Tensor, map: torch.Tensor): + hidden = self.input(z) + output = self.decoder(hidden, map) + return output[:, 0:self.map_size] + +if __name__ == "__main__": + device = torch.device("cpu") + + input = torch.randn(1, 32).to(device) + map = torch.randint(0, 32, [1, 169]).to(device) + + # 初始化模型 + model = GinkaTransformerVAEDecoder().to(device) + model.eval() + + print_memory("初始化后") + + # 前向传播 + start = time.perf_counter() + output = model(input, map) + end = time.perf_counter() + + print_memory("前向传播后") + + print(f"推理耗时: {end - start}") + print(f"输出形状: output={output.shape}") + print(f"Input Embedding parameters: {sum(p.numel() for p in model.input.parameters())}") + print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/transformer_vae/encoder.py b/ginka/transformer_vae/encoder.py new file mode 100644 index 0000000..67f9ea8 --- /dev/null +++ b/ginka/transformer_vae/encoder.py @@ -0,0 +1,96 @@ +import time +import torch +import torch.nn as nn +from ..utils import print_memory + +class GinkaTransformerEncoder(nn.Module): + def __init__(self, dim_ff=256, nhead=4, num_layers=4): + super().__init__() + self.dim_ff = dim_ff + self.encoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True), + num_layers=num_layers + ) + self.decoder = nn.TransformerDecoder( + nn.TransformerDecoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True), + num_layers=max(num_layers // 2, 1) + ) + + def forward(self, x: torch.Tensor): + # x: [B, H * W, S] + B, L, S = x.shape + first_token = torch.randn(B, 1, self.dim_ff).to(x.device) + x = self.encoder(x) + x = self.decoder(first_token, x) + return x.squeeze(1) + +class GinkaTransformerBottleneck(nn.Module): + def __init__(self, dim_ff=256, hidden_dim=512, latent_dim=32): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(dim_ff, hidden_dim), + nn.Dropout(0.3), + nn.LayerNorm(hidden_dim), + nn.ReLU(), + ) + self.fc_mu = nn.Sequential( + nn.Linear(hidden_dim, latent_dim) + ) + self.fc_logvar = nn.Sequential( + nn.Linear(hidden_dim, latent_dim) + ) + + def forward(self, x): + # x: [B, dim_ff] + hidden = self.fc(x) + mu = self.fc_mu(hidden) + logvar = self.fc_logvar(hidden) + return mu, logvar + +class GinkaTransformerVAEEncoder(nn.Module): + def __init__( + self, num_classes=32, latent_dim=32, bottleneck_dim=512, dim_ff=256, + nhead=4, num_layers=4, map_size=13*13 + ): + super().__init__() + self.map_size = map_size + self.embedding = nn.Embedding(num_classes, dim_ff) + self.pos_embedding = nn.Embedding(map_size, dim_ff) + self.encoder = GinkaTransformerEncoder(dim_ff=dim_ff, nhead=nhead, num_layers=num_layers) + self.bottleneck = GinkaTransformerBottleneck( + dim_ff=dim_ff, hidden_dim=bottleneck_dim, latent_dim=latent_dim + ) + + def forward(self, x: torch.Tensor): + # x: [B, map_size] + pos = self.pos_embedding(torch.arange(self.map_size, dtype=torch.long).to(x.device)) + x = self.embedding(x) + pos + x = self.encoder(x) + mu, logvar = self.bottleneck(x) + return mu, logvar + +if __name__ == "__main__": + device = torch.device("cpu") + + input = torch.randint(0, 32, [1, 169]).to(device) + + # 初始化模型 + model = GinkaTransformerVAEEncoder().to(device) + + print_memory("初始化后") + + # 前向传播 + start = time.perf_counter() + mu, logvar = model(input) + end = time.perf_counter() + + print_memory("前向传播后") + + print(f"推理耗时: {end - start}") + print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}") + print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}") + print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}") + print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}") + print(f"bottleneck parameters: {sum(p.numel() for p in model.bottleneck.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") + diff --git a/ginka/transformer_vae/vae.py b/ginka/transformer_vae/vae.py new file mode 100644 index 0000000..89abb7a --- /dev/null +++ b/ginka/transformer_vae/vae.py @@ -0,0 +1,54 @@ +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +from .encoder import GinkaTransformerVAEEncoder +from .decoder import GinkaTransformerVAEDecoder +from ..utils import print_memory + +class GinkaTransformerVAE(nn.Module): + def __init__(self, num_classes=32, latent_dim=32): + super().__init__() + self.encoder = GinkaTransformerVAEEncoder(num_classes=num_classes, latent_dim=latent_dim) + self.decoder = GinkaTransformerVAEDecoder(latent_dim=latent_dim) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def autoregressive(self): + self.decoder.decoder.autoregressive = True + + def teacher_forcing(self): + self.decoder.decoder.autoregressive = False + + def forward(self, target_map: torch.Tensor, use_self_probility=0): + # target_map: [B, H * W] + mu, logvar = self.encoder(target_map) # [B, latent_dim] + z = self.reparameterize(mu, logvar) + logits = self.decoder(z, target_map) # [B, H * W, num_classes] | [B, H * W] + return logits, mu, logvar + +if __name__ == "__main__": + device = torch.device("cpu") + + input = torch.randint(0, 32, [1, 169]).to(device) + + # 初始化模型 + model = GinkaTransformerVAE().to(device) + + print_memory("初始化后") + + # 前向传播 + start = time.perf_counter() + logits, mu, logvar = model(input) + end = time.perf_counter() + + print_memory("前向传播后") + + print(f"推理耗时: {end - start}") + print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}") + print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}") + print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/vae_rnn/loss.py b/ginka/vae_rnn/loss.py index 9595448..26ae649 100644 --- a/ginka/vae_rnn/loss.py +++ b/ginka/vae_rnn/loss.py @@ -6,8 +6,8 @@ class VAELoss: self.num_classes = 32 def vae_loss(self, logits, target, mu, logvar, beta=0.1): - # target: [B, 13, 13] - target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2) + # target: [B, 169] + target = F.one_hot(target, num_classes=self.num_classes).float() recon_loss = F.cross_entropy(logits, target) kl_loss = -0.5 * torch.mean(