mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-24 05:01:41 +08:00
feat: 为 transformer 加入起始与结束 token
This commit is contained in:
parent
d2ecc9e8c8
commit
973434553a
4
data/shell.sh
Normal file
4
data/shell.sh
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# 自动处理塔信息
|
||||||
|
pnpm auto "result.json" "F:/mota-ai/total data/towerinfo.json" "F:\mota-ai\total data/games"
|
||||||
|
# 将数据按比例划分为训练集和验证集
|
||||||
|
pnpm eval "ginka-dataset.json" "ginka-eval.json" "result.json" 0.02
|
||||||
@ -240,17 +240,17 @@ const labelConfig: IAutoLabelConfig = {
|
|||||||
classes: {
|
classes: {
|
||||||
empty: 0,
|
empty: 0,
|
||||||
wall: 1,
|
wall: 1,
|
||||||
decoration: 2,
|
decoration: 16,
|
||||||
commonDoors: [3],
|
commonDoors: [2],
|
||||||
specialDoors: [6, 6],
|
specialDoors: [2, 2],
|
||||||
keys: [7],
|
keys: [3],
|
||||||
redGems: [10],
|
redGems: [4],
|
||||||
blueGems: [13],
|
blueGems: [5],
|
||||||
greenGems: [16],
|
greenGems: [6],
|
||||||
potions: [19],
|
potions: [7],
|
||||||
items: [23],
|
items: [8],
|
||||||
enemies: [26],
|
enemies: [9],
|
||||||
entry: 29
|
entry: 10
|
||||||
},
|
},
|
||||||
allowedSize: [[13, 13]],
|
allowedSize: [[13, 13]],
|
||||||
allowUselessBranch: true,
|
allowUselessBranch: true,
|
||||||
|
|||||||
@ -1,17 +1,17 @@
|
|||||||
// 基本图块定义
|
// 基本图块定义
|
||||||
export const emptyTiles = new Set([0]);
|
export const emptyTiles = new Set([0]);
|
||||||
export const wallTiles = new Set([1]);
|
export const wallTiles = new Set([1]);
|
||||||
export const decorationTiles = new Set([2]);
|
export const decorationTiles = new Set([16]);
|
||||||
export const commonDoorTiles = new Set([3]);
|
export const commonDoorTiles = new Set([2]);
|
||||||
export const specialDoorTiles = new Set([6]);
|
export const specialDoorTiles = new Set([2]);
|
||||||
export const keyTiles = new Set([7]);
|
export const keyTiles = new Set([3]);
|
||||||
export const redGemTiles = new Set([10]);
|
export const redGemTiles = new Set([4]);
|
||||||
export const blueGemTiles = new Set([13]);
|
export const blueGemTiles = new Set([5]);
|
||||||
export const greenGemTiles = new Set([16]);
|
export const greenGemTiles = new Set([6]);
|
||||||
export const potionTiles = new Set([19]);
|
export const potionTiles = new Set([7]);
|
||||||
export const itemTiles = new Set([23]);
|
export const itemTiles = new Set([8]);
|
||||||
export const enemyTiles = new Set([26]);
|
export const enemyTiles = new Set([9]);
|
||||||
export const entryTiles = new Set([29]);
|
export const entryTiles = new Set([10]);
|
||||||
|
|
||||||
// 组合图块定义
|
// 组合图块定义
|
||||||
export const doorTiles = commonDoorTiles.union(specialDoorTiles);
|
export const doorTiles = commonDoorTiles.union(specialDoorTiles);
|
||||||
|
|||||||
@ -42,16 +42,8 @@ from shared.image import matrix_to_image_cv
|
|||||||
# 12. 咸鱼门数量(多层咸鱼门只算一个)
|
# 12. 咸鱼门数量(多层咸鱼门只算一个)
|
||||||
|
|
||||||
# 图块定义:
|
# 图块定义:
|
||||||
# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地),
|
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
|
||||||
# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门
|
# 8. 道具, 9. 怪物, 10. 入口, 14. 起始 token, 15. 终止 token
|
||||||
# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启
|
|
||||||
# 10-12. 三种等级的红宝石
|
|
||||||
# 13-15. 三种等级的蓝宝石
|
|
||||||
# 16-18. 三种等级的绿宝石
|
|
||||||
# 19-22. 四种等级的血瓶
|
|
||||||
# 23-25. 三种等级的道具
|
|
||||||
# 26-28. 三种等级的怪物
|
|
||||||
# 29. 入口,不区分楼梯和箭头
|
|
||||||
|
|
||||||
BATCH_SIZE = 8
|
BATCH_SIZE = 8
|
||||||
LATENT_DIM = 48
|
LATENT_DIM = 48
|
||||||
@ -60,6 +52,7 @@ SELF_GATE = 0.5
|
|||||||
GATE_EPOCH = 5
|
GATE_EPOCH = 5
|
||||||
VAL_BATCH_DIVIDER = 8
|
VAL_BATCH_DIVIDER = 8
|
||||||
PROB_STEP = 0.05
|
PROB_STEP = 0.05
|
||||||
|
NUM_CLASSES = 16
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
"cuda:1" if torch.cuda.is_available()
|
"cuda:1" if torch.cuda.is_available()
|
||||||
@ -89,7 +82,7 @@ def train():
|
|||||||
|
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
vae = GinkaTransformerVAE(latent_dim=LATENT_DIM).to(device)
|
vae = GinkaTransformerVAE(num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(device)
|
||||||
|
|
||||||
dataset = GinkaRNNDataset(args.train, device)
|
dataset = GinkaRNNDataset(args.train, device)
|
||||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ class GinkaTransformerDecoder(nn.Module):
|
|||||||
self.dim_ff = dim_ff
|
self.dim_ff = dim_ff
|
||||||
self.map_size = map_size
|
self.map_size = map_size
|
||||||
self.embedding = nn.Embedding(num_classes, dim_ff)
|
self.embedding = nn.Embedding(num_classes, dim_ff)
|
||||||
self.pos_embedding = nn.Embedding(map_size, dim_ff)
|
self.pos_embedding = nn.Embedding(map_size + 1, dim_ff)
|
||||||
self.encoder = nn.TransformerEncoder(
|
self.encoder = nn.TransformerEncoder(
|
||||||
nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True),
|
nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True),
|
||||||
num_layers=max(num_layers // 2, 1)
|
num_layers=max(num_layers // 2, 1)
|
||||||
@ -31,12 +31,14 @@ class GinkaTransformerDecoder(nn.Module):
|
|||||||
B, L = target_map.shape
|
B, L = target_map.shape
|
||||||
|
|
||||||
memory = self.encoder(z.unsqueeze(1)) # [B, 1, dim_ff]
|
memory = self.encoder(z.unsqueeze(1)) # [B, 1, dim_ff]
|
||||||
mask = torch.triu(torch.ones(L, L, dtype=torch.bool)).to(z.device) # [B, H * W, H * W]
|
mask = torch.triu(torch.ones(L + 1, L + 1, dtype=torch.bool)).to(z.device) # [B, H * W, H * W]
|
||||||
|
|
||||||
# when training, use teacher forcing
|
# when training, use teacher forcing
|
||||||
if not self.autoregressive:
|
if not self.autoregressive:
|
||||||
map = self.embedding(target_map)
|
first_token = torch.tensor([31], dtype=torch.long).to(z.device).repeat(B, 1)
|
||||||
pos_embed = self.pos_embedding(torch.arange(L, dtype=torch.long).to(z.device))
|
with_first = torch.cat([first_token, target_map], dim=1)
|
||||||
|
map = self.embedding(with_first)
|
||||||
|
pos_embed = self.pos_embedding(torch.arange(L + 1, dtype=torch.long).to(z.device))
|
||||||
map = map + pos_embed # [B, H * W, dim_ff]
|
map = map + pos_embed # [B, H * W, dim_ff]
|
||||||
decoded = self.decoder(map, memory, tgt_mask=mask) # [B, H * W, dim_ff]
|
decoded = self.decoder(map, memory, tgt_mask=mask) # [B, H * W, dim_ff]
|
||||||
output = self.fc(decoded)
|
output = self.fc(decoded)
|
||||||
@ -44,7 +46,7 @@ class GinkaTransformerDecoder(nn.Module):
|
|||||||
|
|
||||||
# when evaling, use autoregressive generation
|
# when evaling, use autoregressive generation
|
||||||
else:
|
else:
|
||||||
output = torch.zeros([B, L], dtype=torch.int).to(z.device)
|
output = torch.zeros([B, L + 1], dtype=torch.int).to(z.device)
|
||||||
for idx in range(0, self.map_size):
|
for idx in range(0, self.map_size):
|
||||||
embed = self.embedding(output)
|
embed = self.embedding(output)
|
||||||
pos_embed = self.pos_embedding(torch.IntTensor([idx]).repeat(B, 1).to(z.device))
|
pos_embed = self.pos_embedding(torch.IntTensor([idx]).repeat(B, 1).to(z.device))
|
||||||
@ -77,7 +79,7 @@ class GinkaTransformerVAEDecoder(nn.Module):
|
|||||||
def forward(self, z: torch.Tensor, map: torch.Tensor):
|
def forward(self, z: torch.Tensor, map: torch.Tensor):
|
||||||
hidden = self.input(z)
|
hidden = self.input(z)
|
||||||
output = self.decoder(hidden, map)
|
output = self.decoder(hidden, map)
|
||||||
return output[:, 0:self.map_size]
|
return output
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
@ -87,7 +89,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# 初始化模型
|
# 初始化模型
|
||||||
model = GinkaTransformerVAEDecoder().to(device)
|
model = GinkaTransformerVAEDecoder().to(device)
|
||||||
model.eval()
|
|
||||||
|
|
||||||
print_memory("初始化后")
|
print_memory("初始化后")
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,8 @@ class VAELoss:
|
|||||||
|
|
||||||
def vae_loss(self, logits, target, mu, logvar, beta=0.1):
|
def vae_loss(self, logits, target, mu, logvar, beta=0.1):
|
||||||
# target: [B, 169]
|
# target: [B, 169]
|
||||||
|
end_token = torch.tensor([15], dtype=torch.long).to(logits.device)
|
||||||
|
target = torch.cat([target, end_token], dim=1)
|
||||||
target = F.one_hot(target, num_classes=self.num_classes).float()
|
target = F.one_hot(target, num_classes=self.num_classes).float()
|
||||||
recon_loss = F.cross_entropy(logits, target)
|
recon_loss = F.cross_entropy(logits, target)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user