mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 为 transformer 加入起始与结束 token
This commit is contained in:
parent
d2ecc9e8c8
commit
973434553a
4
data/shell.sh
Normal file
4
data/shell.sh
Normal file
@@ -0,0 +1,4 @@
|
||||
# 自动处理塔信息
|
||||
pnpm auto "result.json" "F:/mota-ai/total data/towerinfo.json" "F:\mota-ai\total data/games"
|
||||
# 将数据按比例区分为训练集和数据集
|
||||
pnpm eval "ginka-dataset.json" "ginka-eval.json" "result.json" 0.02
|
||||
@@ -240,17 +240,17 @@ const labelConfig: IAutoLabelConfig = {
|
||||
classes: {
|
||||
empty: 0,
|
||||
wall: 1,
|
||||
decoration: 2,
|
||||
commonDoors: [3],
|
||||
specialDoors: [6, 6],
|
||||
keys: [7],
|
||||
redGems: [10],
|
||||
blueGems: [13],
|
||||
greenGems: [16],
|
||||
potions: [19],
|
||||
items: [23],
|
||||
enemies: [26],
|
||||
entry: 29
|
||||
decoration: 16,
|
||||
commonDoors: [2],
|
||||
specialDoors: [2, 2],
|
||||
keys: [3],
|
||||
redGems: [4],
|
||||
blueGems: [5],
|
||||
greenGems: [6],
|
||||
potions: [7],
|
||||
items: [8],
|
||||
enemies: [9],
|
||||
entry: 10
|
||||
},
|
||||
allowedSize: [[13, 13]],
|
||||
allowUselessBranch: true,
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
// 基本图块定义
|
||||
export const emptyTiles = new Set([0]);
|
||||
export const wallTiles = new Set([1]);
|
||||
export const decorationTiles = new Set([2]);
|
||||
export const commonDoorTiles = new Set([3]);
|
||||
export const specialDoorTiles = new Set([6]);
|
||||
export const keyTiles = new Set([7]);
|
||||
export const redGemTiles = new Set([10]);
|
||||
export const blueGemTiles = new Set([13]);
|
||||
export const greenGemTiles = new Set([16]);
|
||||
export const potionTiles = new Set([19]);
|
||||
export const itemTiles = new Set([23]);
|
||||
export const enemyTiles = new Set([26]);
|
||||
export const entryTiles = new Set([29]);
|
||||
export const decorationTiles = new Set([16]);
|
||||
export const commonDoorTiles = new Set([2]);
|
||||
export const specialDoorTiles = new Set([2]);
|
||||
export const keyTiles = new Set([3]);
|
||||
export const redGemTiles = new Set([4]);
|
||||
export const blueGemTiles = new Set([5]);
|
||||
export const greenGemTiles = new Set([6]);
|
||||
export const potionTiles = new Set([7]);
|
||||
export const itemTiles = new Set([8]);
|
||||
export const enemyTiles = new Set([9]);
|
||||
export const entryTiles = new Set([10]);
|
||||
|
||||
// 组合图块定义
|
||||
export const doorTiles = commonDoorTiles.union(specialDoorTiles);
|
||||
|
||||
@@ -42,16 +42,8 @@ from shared.image import matrix_to_image_cv
|
||||
# 12. 咸鱼门数量(多层咸鱼门只算一个)
|
||||
|
||||
# 图块定义:
|
||||
# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地),
|
||||
# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门
|
||||
# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启
|
||||
# 10-12. 三种等级的红宝石
|
||||
# 13-15. 三种等级的蓝宝石
|
||||
# 16-18. 三种等级的绿宝石
|
||||
# 19-22. 四种等级的血瓶
|
||||
# 23-25. 三种等级的道具
|
||||
# 26-28. 三种等级的怪物
|
||||
# 29. 入口,不区分楼梯和箭头
|
||||
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
|
||||
# 8. 道具, 9. 怪物, 10. 入口, 14. 起始 token, 15. 终止 token
|
||||
|
||||
BATCH_SIZE = 8
|
||||
LATENT_DIM = 48
|
||||
@@ -60,6 +52,7 @@ SELF_GATE = 0.5
|
||||
GATE_EPOCH = 5
|
||||
VAL_BATCH_DIVIDER = 8
|
||||
PROB_STEP = 0.05
|
||||
NUM_CLASSES = 16
|
||||
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
@@ -89,7 +82,7 @@ def train():
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
vae = GinkaTransformerVAE(latent_dim=LATENT_DIM).to(device)
|
||||
vae = GinkaTransformerVAE(num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(device)
|
||||
|
||||
dataset = GinkaRNNDataset(args.train, device)
|
||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||
|
||||
@@ -10,7 +10,7 @@ class GinkaTransformerDecoder(nn.Module):
|
||||
self.dim_ff = dim_ff
|
||||
self.map_size = map_size
|
||||
self.embedding = nn.Embedding(num_classes, dim_ff)
|
||||
self.pos_embedding = nn.Embedding(map_size, dim_ff)
|
||||
self.pos_embedding = nn.Embedding(map_size + 1, dim_ff)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True),
|
||||
num_layers=max(num_layers // 2, 1)
|
||||
@@ -31,12 +31,14 @@ class GinkaTransformerDecoder(nn.Module):
|
||||
B, L = target_map.shape
|
||||
|
||||
memory = self.encoder(z.unsqueeze(1)) # [B, 1, dim_ff]
|
||||
mask = torch.triu(torch.ones(L, L, dtype=torch.bool)).to(z.device) # [B, H * W, H * W]
|
||||
mask = torch.triu(torch.ones(L + 1, L + 1, dtype=torch.bool)).to(z.device) # [B, H * W, H * W]
|
||||
|
||||
# when training, use teacher forcing
|
||||
if not self.autoregressive:
|
||||
map = self.embedding(target_map)
|
||||
pos_embed = self.pos_embedding(torch.arange(L, dtype=torch.long).to(z.device))
|
||||
first_token = torch.tensor([31], dtype=torch.long).to(z.device).repeat(B, 1)
|
||||
with_first = torch.cat([first_token, target_map], dim=1)
|
||||
map = self.embedding(with_first)
|
||||
pos_embed = self.pos_embedding(torch.arange(L + 1, dtype=torch.long).to(z.device))
|
||||
map = map + pos_embed # [B, H * W, dim_ff]
|
||||
decoded = self.decoder(map, memory, tgt_mask=mask) # [B, H * W, dim_ff]
|
||||
output = self.fc(decoded)
|
||||
@@ -44,7 +46,7 @@ class GinkaTransformerDecoder(nn.Module):
|
||||
|
||||
# when evaling, use autoregressive generation
|
||||
else:
|
||||
output = torch.zeros([B, L], dtype=torch.int).to(z.device)
|
||||
output = torch.zeros([B, L + 1], dtype=torch.int).to(z.device)
|
||||
for idx in range(0, self.map_size):
|
||||
embed = self.embedding(output)
|
||||
pos_embed = self.pos_embedding(torch.IntTensor([idx]).repeat(B, 1).to(z.device))
|
||||
@@ -77,7 +79,7 @@ class GinkaTransformerVAEDecoder(nn.Module):
|
||||
def forward(self, z: torch.Tensor, map: torch.Tensor):
|
||||
hidden = self.input(z)
|
||||
output = self.decoder(hidden, map)
|
||||
return output[:, 0:self.map_size]
|
||||
return output
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = torch.device("cpu")
|
||||
@@ -87,7 +89,6 @@ if __name__ == "__main__":
|
||||
|
||||
# 初始化模型
|
||||
model = GinkaTransformerVAEDecoder().to(device)
|
||||
model.eval()
|
||||
|
||||
print_memory("初始化后")
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ class VAELoss:
|
||||
|
||||
def vae_loss(self, logits, target, mu, logvar, beta=0.1):
|
||||
# target: [B, 169]
|
||||
end_token = torch.tensor([15], dtype=torch.long).to(logits.device)
|
||||
target = torch.cat([target, end_token], dim=1)
|
||||
target = F.one_hot(target, num_classes=self.num_classes).float()
|
||||
recon_loss = F.cross_entropy(logits, target)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user