feat: 为 transformer 加入起始与结束 token

This commit is contained in:
unanmed 2026-03-10 18:12:42 +08:00
parent d2ecc9e8c8
commit 973434553a
6 changed files with 40 additions and 40 deletions

4
data/shell.sh Normal file
View File

@ -0,0 +1,4 @@
# Auto-process tower information from the raw game data.
# NOTE(review): mixed `/` and `\` path separators in the argument below — confirm intended on the target shell.
pnpm auto "result.json" "F:/mota-ai/total data/towerinfo.json" "F:\mota-ai\total data/games"
# Split the data by ratio (0.02 held out) into a training set and an evaluation set.
pnpm eval "ginka-dataset.json" "ginka-eval.json" "result.json" 0.02

View File

@ -240,17 +240,17 @@ const labelConfig: IAutoLabelConfig = {
classes: {
empty: 0,
wall: 1,
decoration: 2,
commonDoors: [3],
specialDoors: [6, 6],
keys: [7],
redGems: [10],
blueGems: [13],
greenGems: [16],
potions: [19],
items: [23],
enemies: [26],
entry: 29
decoration: 16,
commonDoors: [2],
specialDoors: [2, 2],
keys: [3],
redGems: [4],
blueGems: [5],
greenGems: [6],
potions: [7],
items: [8],
enemies: [9],
entry: 10
},
allowedSize: [[13, 13]],
allowUselessBranch: true,

View File

@ -1,17 +1,17 @@
// 基本图块定义
export const emptyTiles = new Set([0]);
export const wallTiles = new Set([1]);
export const decorationTiles = new Set([2]);
export const commonDoorTiles = new Set([3]);
export const specialDoorTiles = new Set([6]);
export const keyTiles = new Set([7]);
export const redGemTiles = new Set([10]);
export const blueGemTiles = new Set([13]);
export const greenGemTiles = new Set([16]);
export const potionTiles = new Set([19]);
export const itemTiles = new Set([23]);
export const enemyTiles = new Set([26]);
export const entryTiles = new Set([29]);
export const decorationTiles = new Set([16]);
export const commonDoorTiles = new Set([2]);
export const specialDoorTiles = new Set([2]);
export const keyTiles = new Set([3]);
export const redGemTiles = new Set([4]);
export const blueGemTiles = new Set([5]);
export const greenGemTiles = new Set([6]);
export const potionTiles = new Set([7]);
export const itemTiles = new Set([8]);
export const enemyTiles = new Set([9]);
export const entryTiles = new Set([10]);
// 组合图块定义
export const doorTiles = commonDoorTiles.union(specialDoorTiles);

View File

@ -42,16 +42,8 @@ from shared.image import matrix_to_image_cv
# 12. 咸鱼门数量(多层咸鱼门只算一个)
# 图块定义:
# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地),
# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门
# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启
# 10-12. 三种等级的红宝石
# 13-15. 三种等级的蓝宝石
# 16-18. 三种等级的绿宝石
# 19-22. 四种等级的血瓶
# 23-25. 三种等级的道具
# 26-28. 三种等级的怪物
# 29. 入口,不区分楼梯和箭头
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
# 8. 道具, 9. 怪物, 10. 入口, 14. 起始 token, 15. 终止 token
# NOTE(review): the decoder builds its start token as 31 (`torch.tensor([31])` in
# GinkaTransformerDecoder.forward), which disagrees with the 14 documented here and
# exceeds the `nn.Embedding(num_classes=16, ...)` index range — confirm which value is intended.
BATCH_SIZE = 8
LATENT_DIM = 48
@ -60,6 +52,7 @@ SELF_GATE = 0.5
GATE_EPOCH = 5
VAL_BATCH_DIVIDER = 8
PROB_STEP = 0.05
NUM_CLASSES = 16
device = torch.device(
"cuda:1" if torch.cuda.is_available()
@ -89,7 +82,7 @@ def train():
args = parse_arguments()
vae = GinkaTransformerVAE(latent_dim=LATENT_DIM).to(device)
vae = GinkaTransformerVAE(num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(device)
dataset = GinkaRNNDataset(args.train, device)
dataset_val = GinkaRNNDataset(args.validate, device)

View File

@ -10,7 +10,7 @@ class GinkaTransformerDecoder(nn.Module):
self.dim_ff = dim_ff
self.map_size = map_size
self.embedding = nn.Embedding(num_classes, dim_ff)
self.pos_embedding = nn.Embedding(map_size, dim_ff)
self.pos_embedding = nn.Embedding(map_size + 1, dim_ff)
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True),
num_layers=max(num_layers // 2, 1)
@ -31,12 +31,14 @@ class GinkaTransformerDecoder(nn.Module):
B, L = target_map.shape
memory = self.encoder(z.unsqueeze(1)) # [B, 1, dim_ff]
mask = torch.triu(torch.ones(L, L, dtype=torch.bool)).to(z.device) # [B, H * W, H * W]
mask = torch.triu(torch.ones(L + 1, L + 1, dtype=torch.bool)).to(z.device) # [B, H * W, H * W]
# when training, use teacher forcing
if not self.autoregressive:
map = self.embedding(target_map)
pos_embed = self.pos_embedding(torch.arange(L, dtype=torch.long).to(z.device))
first_token = torch.tensor([31], dtype=torch.long).to(z.device).repeat(B, 1)
with_first = torch.cat([first_token, target_map], dim=1)
map = self.embedding(with_first)
pos_embed = self.pos_embedding(torch.arange(L + 1, dtype=torch.long).to(z.device))
map = map + pos_embed # [B, H * W, dim_ff]
decoded = self.decoder(map, memory, tgt_mask=mask) # [B, H * W, dim_ff]
output = self.fc(decoded)
@ -44,7 +46,7 @@ class GinkaTransformerDecoder(nn.Module):
# when evaluating, use autoregressive generation
else:
output = torch.zeros([B, L], dtype=torch.int).to(z.device)
output = torch.zeros([B, L + 1], dtype=torch.int).to(z.device)
for idx in range(0, self.map_size):
embed = self.embedding(output)
pos_embed = self.pos_embedding(torch.IntTensor([idx]).repeat(B, 1).to(z.device))
@ -77,7 +79,7 @@ class GinkaTransformerVAEDecoder(nn.Module):
def forward(self, z: torch.Tensor, map: torch.Tensor):
hidden = self.input(z)
output = self.decoder(hidden, map)
return output[:, 0:self.map_size]
return output
if __name__ == "__main__":
device = torch.device("cpu")
@ -87,7 +89,6 @@ if __name__ == "__main__":
# 初始化模型
model = GinkaTransformerVAEDecoder().to(device)
model.eval()
print_memory("初始化后")

View File

@ -7,6 +7,8 @@ class VAELoss:
def vae_loss(self, logits, target, mu, logvar, beta=0.1):
# target: [B, 169]
end_token = torch.tensor([15], dtype=torch.long).to(logits.device)
target = torch.cat([target, end_token], dim=1)
target = F.one_hot(target, num_classes=self.num_classes).float()
recon_loss = F.cross_entropy(logits, target)