From 973434553a22083386f30d55c3c0a974eaf32ce0 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 10 Mar 2026 18:12:42 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=B8=BA=20transformer=20=E5=8A=A0?= =?UTF-8?q?=E5=85=A5=E8=B5=B7=E5=A7=8B=E4=B8=8E=E7=BB=93=E6=9D=9F=20token?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/shell.sh | 4 ++++ data/src/auto.ts | 22 +++++++++++----------- data/src/shared.ts | 22 +++++++++++----------- ginka/train_transformer_vae.py | 15 ++++----------- ginka/transformer_vae/decoder.py | 15 ++++++++------- ginka/vae_rnn/loss.py | 2 ++ 6 files changed, 40 insertions(+), 40 deletions(-) create mode 100644 data/shell.sh diff --git a/data/shell.sh b/data/shell.sh new file mode 100644 index 0000000..f896cf6 --- /dev/null +++ b/data/shell.sh @@ -0,0 +1,4 @@ +# 自动处理塔信息 +pnpm auto "result.json" "F:/mota-ai/total data/towerinfo.json" "F:\mota-ai\total data/games" +# 将数据按比例区分为训练集和数据集 +pnpm eval "ginka-dataset.json" "ginka-eval.json" "result.json" 0.02 diff --git a/data/src/auto.ts b/data/src/auto.ts index 61b5f08..c581a26 100644 --- a/data/src/auto.ts +++ b/data/src/auto.ts @@ -240,17 +240,17 @@ const labelConfig: IAutoLabelConfig = { classes: { empty: 0, wall: 1, - decoration: 2, - commonDoors: [3], - specialDoors: [6, 6], - keys: [7], - redGems: [10], - blueGems: [13], - greenGems: [16], - potions: [19], - items: [23], - enemies: [26], - entry: 29 + decoration: 16, + commonDoors: [2], + specialDoors: [2, 2], + keys: [3], + redGems: [4], + blueGems: [5], + greenGems: [6], + potions: [7], + items: [8], + enemies: [9], + entry: 10 }, allowedSize: [[13, 13]], allowUselessBranch: true, diff --git a/data/src/shared.ts b/data/src/shared.ts index 4d38968..da6ca03 100644 --- a/data/src/shared.ts +++ b/data/src/shared.ts @@ -1,17 +1,17 @@ // 基本图块定义 export const emptyTiles = new Set([0]); export const wallTiles = new Set([1]); -export const decorationTiles = new Set([2]); -export const 
commonDoorTiles = new Set([3]); -export const specialDoorTiles = new Set([6]); -export const keyTiles = new Set([7]); -export const redGemTiles = new Set([10]); -export const blueGemTiles = new Set([13]); -export const greenGemTiles = new Set([16]); -export const potionTiles = new Set([19]); -export const itemTiles = new Set([23]); -export const enemyTiles = new Set([26]); -export const entryTiles = new Set([29]); +export const decorationTiles = new Set([16]); +export const commonDoorTiles = new Set([2]); +export const specialDoorTiles = new Set([2]); +export const keyTiles = new Set([3]); +export const redGemTiles = new Set([4]); +export const blueGemTiles = new Set([5]); +export const greenGemTiles = new Set([6]); +export const potionTiles = new Set([7]); +export const itemTiles = new Set([8]); +export const enemyTiles = new Set([9]); +export const entryTiles = new Set([10]); // 组合图块定义 export const doorTiles = commonDoorTiles.union(specialDoorTiles); diff --git a/ginka/train_transformer_vae.py b/ginka/train_transformer_vae.py index 59258bd..84da62f 100644 --- a/ginka/train_transformer_vae.py +++ b/ginka/train_transformer_vae.py @@ -42,16 +42,8 @@ from shared.image import matrix_to_image_cv # 12. 咸鱼门数量(多层咸鱼门只算一个) # 图块定义: -# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地), -# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门 -# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启 -# 10-12. 三种等级的红宝石 -# 13-15. 三种等级的蓝宝石 -# 16-18. 三种等级的绿宝石 -# 19-22. 四种等级的血瓶 -# 23-25. 三种等级的道具 -# 26-28. 三种等级的怪物 -# 29. 入口,不区分楼梯和箭头 +# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶 +# 8. 道具, 9. 怪物, 10. 入口, 14. 起始 token, 15. 
终止 token
 
 BATCH_SIZE = 8
 LATENT_DIM = 48
@@ -60,6 +52,7 @@ SELF_GATE = 0.5
 GATE_EPOCH = 5
 VAL_BATCH_DIVIDER = 8
 PROB_STEP = 0.05
+NUM_CLASSES = 16
 
 device = torch.device(
     "cuda:1" if torch.cuda.is_available()
@@ -89,7 +82,7 @@ def train():
 
     args = parse_arguments()
     
-    vae = GinkaTransformerVAE(latent_dim=LATENT_DIM).to(device)
+    vae = GinkaTransformerVAE(num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(device)
     dataset = GinkaRNNDataset(args.train, device)
     dataset_val = GinkaRNNDataset(args.validate, device)
 
diff --git a/ginka/transformer_vae/decoder.py b/ginka/transformer_vae/decoder.py
index b02b8de..cb720d0 100644
--- a/ginka/transformer_vae/decoder.py
+++ b/ginka/transformer_vae/decoder.py
@@ -10,7 +10,7 @@ class GinkaTransformerDecoder(nn.Module):
         self.dim_ff = dim_ff
         self.map_size = map_size
         self.embedding = nn.Embedding(num_classes, dim_ff)
-        self.pos_embedding = nn.Embedding(map_size, dim_ff)
+        self.pos_embedding = nn.Embedding(map_size + 1, dim_ff)
         self.encoder = nn.TransformerEncoder(
             nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True),
             num_layers=max(num_layers // 2, 1)
@@ -31,12 +31,14 @@ class GinkaTransformerDecoder(nn.Module):
         B, L = target_map.shape
         memory = self.encoder(z.unsqueeze(1)) # [B, 1, dim_ff]
 
-        mask = torch.triu(torch.ones(L, L, dtype=torch.bool)).to(z.device) # [B, H * W, H * W]
+        mask = torch.triu(torch.ones(L + 1, L + 1, dtype=torch.bool)).to(z.device) # [B, H * W, H * W]
         # when training, use teacher forcing
         if not self.autoregressive:
-            map = self.embedding(target_map)
-            pos_embed = self.pos_embedding(torch.arange(L, dtype=torch.long).to(z.device))
+            first_token = torch.tensor([14], dtype=torch.long).to(z.device).repeat(B, 1)
+            with_first = torch.cat([first_token, target_map], dim=1)
+            map = self.embedding(with_first)
+            pos_embed = self.pos_embedding(torch.arange(L + 1, dtype=torch.long).to(z.device))
             map = map + pos_embed # [B, H * W, dim_ff]
             decoded = self.decoder(map, memory, 
tgt_mask=mask) # [B, H * W, dim_ff]
             output = self.fc(decoded)
@@ -44,7 +46,7 @@
         # when evaling, use autoregressive generation
         else:
-            output = torch.zeros([B, L], dtype=torch.int).to(z.device)
+            output = torch.zeros([B, L + 1], dtype=torch.int).to(z.device)
             for idx in range(0, self.map_size):
                 embed = self.embedding(output)
                 pos_embed = self.pos_embedding(torch.IntTensor([idx]).repeat(B, 1).to(z.device))
@@ -77,7 +79,7 @@ class GinkaTransformerVAEDecoder(nn.Module):
     def forward(self, z: torch.Tensor, map: torch.Tensor):
         hidden = self.input(z)
         output = self.decoder(hidden, map)
-        return output[:, 0:self.map_size]
+        return output
 
 if __name__ == "__main__":
     device = torch.device("cpu")
@@ -87,7 +89,6 @@ if __name__ == "__main__":
 
     # 初始化模型
     model = GinkaTransformerVAEDecoder().to(device)
-    model.eval()
     
     print_memory("初始化后")
 
diff --git a/ginka/vae_rnn/loss.py b/ginka/vae_rnn/loss.py
index 26ae649..4ebebee 100644
--- a/ginka/vae_rnn/loss.py
+++ b/ginka/vae_rnn/loss.py
@@ -7,6 +7,8 @@ class VAELoss:
     
     def vae_loss(self, logits, target, mu, logvar, beta=0.1):
         # target: [B, 169]
+        end_token = torch.tensor([15], dtype=torch.long).to(logits.device).repeat(target.size(0), 1)
+        target = torch.cat([target, end_token], dim=1)
         target = F.one_hot(target, num_classes=self.num_classes).float()
         recon_loss = F.cross_entropy(logits, target)