mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 编码器加入 transformer 整合,改进解码器
This commit is contained in:
parent
8caa37a144
commit
b7fe24ee4c
@ -124,6 +124,7 @@ def train():
|
||||
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05)
|
||||
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=2.0)
|
||||
optimizer_ginka.step()
|
||||
loss_total += loss.detach()
|
||||
reco_loss_total += reco_loss.detach()
|
||||
|
||||
@ -5,7 +5,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..utils import print_memory
|
||||
|
||||
class GinkaMapPatch(nn.Module):
|
||||
class DecoderMapPatch(nn.Module):
|
||||
def __init__(self, tile_classes=32, width=13, height=13):
|
||||
super().__init__()
|
||||
|
||||
@ -64,18 +64,16 @@ class GinkaMapPatch(nn.Module):
|
||||
feat = self.fc(feat)
|
||||
return feat
|
||||
|
||||
class GinkaTileEmbedding(nn.Module):
|
||||
class DecoderTileEmbedding(nn.Module):
|
||||
def __init__(self, tile_classes=32, embed_dim=256):
|
||||
super().__init__()
|
||||
|
||||
# 图块编码,上一次画的图块
|
||||
|
||||
self.embedding = nn.Embedding(tile_classes, embed_dim)
|
||||
|
||||
def forward(self, tile: torch.Tensor):
|
||||
return self.embedding(tile)
|
||||
|
||||
class GinkaPosEmbedding(nn.Module):
|
||||
class DecoderPosEmbedding(nn.Module):
|
||||
def __init__(self, width=13, height=13, embed_dim=256):
|
||||
super().__init__()
|
||||
|
||||
@ -86,17 +84,13 @@ class GinkaPosEmbedding(nn.Module):
|
||||
|
||||
self.row_embedding = nn.Embedding(height, embed_dim)
|
||||
self.col_embedding = nn.Embedding(width, embed_dim)
|
||||
self.fusion = nn.Linear(embed_dim * 2, embed_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||
row = self.row_embedding(y)
|
||||
col = self.col_embedding(x)
|
||||
embed = torch.cat([row, col], dim=2)
|
||||
fused = self.fusion(embed)
|
||||
|
||||
return fused
|
||||
return row, col
|
||||
|
||||
class GinkaInputFusion(nn.Module):
|
||||
class DecoderInputFusion(nn.Module):
|
||||
def __init__(self, d_model=256):
|
||||
super().__init__()
|
||||
|
||||
@ -107,24 +101,31 @@ class GinkaInputFusion(nn.Module):
|
||||
d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True,
|
||||
dropout=0.2
|
||||
),
|
||||
num_layers=3
|
||||
num_layers=2
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
self.fusion = nn.Linear(d_model * 2, d_model)
|
||||
|
||||
def forward(
|
||||
self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
|
||||
pos_embed: torch.Tensor, patch_vec: torch.Tensor
|
||||
col_embed: torch.Tensor, row_embed: torch.Tensor, patch_vec: torch.Tensor
|
||||
):
|
||||
"""
|
||||
tile_embed: [B, 256]
|
||||
cond_vec: [B, 256]
|
||||
pos_embed: [B, 256]
|
||||
col_embed: [B, 256]
|
||||
row_embed: [B, 256]
|
||||
patch_vec: [B, 256]
|
||||
"""
|
||||
vec = torch.stack([tile_embed, cond_vec, pos_embed, patch_vec], dim=1)
|
||||
feat = self.transformer(vec)
|
||||
return feat[:, 0]
|
||||
vec = torch.stack([tile_embed, cond_vec, col_embed, row_embed, patch_vec], dim=1)
|
||||
feat = self.norm(self.transformer(vec))
|
||||
mean = torch.mean(feat, dim=1)
|
||||
max = torch.max(feat, dim=1).values
|
||||
hidden = torch.cat([mean, max], dim=1)
|
||||
fused = self.fusion(hidden)
|
||||
return fused
|
||||
|
||||
class GinkaRNN(nn.Module):
|
||||
class DecoderRNN(nn.Module):
|
||||
def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512):
|
||||
super().__init__()
|
||||
|
||||
@ -162,13 +163,16 @@ class VAEDecoder(nn.Module):
|
||||
|
||||
# 模型结构
|
||||
self.map_vec_fc = nn.Sequential(
|
||||
nn.Linear(map_vec_dim, 256)
|
||||
nn.Linear(map_vec_dim, 128),
|
||||
nn.LayerNorm(128),
|
||||
nn.GELU(),
|
||||
nn.Linear(128, 256)
|
||||
)
|
||||
self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes)
|
||||
self.pos_embedding = GinkaPosEmbedding()
|
||||
self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes)
|
||||
self.feat_fusion = GinkaInputFusion()
|
||||
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
|
||||
self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes)
|
||||
self.pos_embedding = DecoderPosEmbedding()
|
||||
self.map_patch = DecoderMapPatch(tile_classes=self.tile_classes)
|
||||
self.feat_fusion = DecoderInputFusion()
|
||||
self.rnn = DecoderRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
|
||||
|
||||
self.col_list = []
|
||||
self.row_list = []
|
||||
@ -181,7 +185,7 @@ class VAEDecoder(nn.Module):
|
||||
"""
|
||||
map_vec: [B, vec_dim]
|
||||
target_map: [B, H, W]
|
||||
use_self: 是否使用自己生成的上一步结果执行下一步
|
||||
use_self_probility: 使用自己生成的上一步结果执行下一步的概率
|
||||
"""
|
||||
B, C = map_vec.shape
|
||||
|
||||
@ -194,7 +198,7 @@ class VAEDecoder(nn.Module):
|
||||
|
||||
col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
|
||||
row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
|
||||
pos_embed = self.pos_embedding(col_list, row_list)
|
||||
col_embed, row_embed = self.pos_embedding(col_list, row_list)
|
||||
|
||||
map_vec = self.map_vec_fc(map_vec)
|
||||
|
||||
@ -206,7 +210,7 @@ class VAEDecoder(nn.Module):
|
||||
use_self = random.random() < use_self_probility
|
||||
map_patch = self.map_patch(map if use_self else target_map, x, y)
|
||||
# 编码特征融合
|
||||
feat = self.feat_fusion(tile_embed, map_vec, pos_embed[:, idx], map_patch)
|
||||
feat = self.feat_fusion(tile_embed, map_vec, col_embed[:, idx], row_embed[:, idx], map_patch)
|
||||
# RNN 输出
|
||||
logits, h = self.rnn(feat, hidden)
|
||||
# 处理输出
|
||||
|
||||
@ -7,9 +7,21 @@ from ..utils import print_memory
|
||||
class EncoderEmbedding(nn.Module):
|
||||
def __init__(self, tile_classes=32, width=13, height=13, hidden_dim=128, output_dim=256):
|
||||
super().__init__()
|
||||
self.tile_embedding = nn.Embedding(tile_classes, hidden_dim)
|
||||
self.col_embedding = nn.Embedding(width, hidden_dim)
|
||||
self.row_embedding = nn.Embedding(height, hidden_dim)
|
||||
self.tile_embedding = nn.Sequential(
|
||||
nn.Embedding(tile_classes, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.GELU()
|
||||
)
|
||||
self.col_embedding = nn.Sequential(
|
||||
nn.Embedding(width, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.GELU()
|
||||
)
|
||||
self.row_embedding = nn.Sequential(
|
||||
nn.Embedding(height, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.GELU()
|
||||
)
|
||||
self.fusion = nn.Linear(hidden_dim * 3, output_dim)
|
||||
|
||||
def forward(self, tile, x, y):
|
||||
@ -43,6 +55,25 @@ class EncoderGRU(nn.Module):
|
||||
hidden = self.drop(self.gru(feat, hidden))
|
||||
logits = self.fc(hidden)
|
||||
return logits, hidden
|
||||
|
||||
class EncoderFusion(nn.Module):
|
||||
def __init__(self, d_model=256):
|
||||
super().__init__()
|
||||
|
||||
self.transformer = nn.TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(
|
||||
d_model=d_model, dim_feedforward=d_model, nhead=2, batch_first=True
|
||||
),
|
||||
num_layers=2
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, logits):
|
||||
x = self.norm(self.transformer(logits))
|
||||
h_mean = torch.mean(x, dim=1)
|
||||
h_max = torch.max(x, dim=1).values
|
||||
h = torch.cat([h_mean, h_max], dim=1)
|
||||
return h
|
||||
|
||||
class VAEEncoder(nn.Module):
|
||||
def __init__(self, device, tile_classes=32, latent_dim=32, width=13, height=13):
|
||||
@ -54,6 +85,7 @@ class VAEEncoder(nn.Module):
|
||||
|
||||
self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
|
||||
self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
|
||||
self.fusion = EncoderFusion(256)
|
||||
self.fc_mu = nn.Linear(512, latent_dim)
|
||||
self.fc_logvar = nn.Linear(512, latent_dim)
|
||||
|
||||
@ -79,9 +111,8 @@ class VAEEncoder(nn.Module):
|
||||
logits, h = self.rnn(embed[:, idx], hidden)
|
||||
hidden = h
|
||||
output[:, idx] = logits
|
||||
h_mean = torch.mean(output, dim=1)
|
||||
h_max = torch.max(output, dim=1).values
|
||||
h = torch.cat([h_mean, h_max], dim=1)
|
||||
|
||||
h = self.fusion(output)
|
||||
mu = self.fc_mu(h)
|
||||
logvar = self.fc_logvar(h)
|
||||
return mu, logvar
|
||||
@ -107,4 +138,5 @@ if __name__ == "__main__":
|
||||
print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
|
||||
print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}")
|
||||
print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
|
||||
print(f"Fusion parameters: {sum(p.numel() for p in model.fusion.parameters())}")
|
||||
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .encoder import VAEEncoder
|
||||
from .decoder import VAEDecoder
|
||||
from ..utils import print_memory
|
||||
|
||||
class GinkaVAE(nn.Module):
|
||||
def __init__(self, device, tile_classes=32, latent_dim=32):
|
||||
@ -19,4 +21,27 @@ class GinkaVAE(nn.Module):
|
||||
mu, logvar = self.encoder(target_map)
|
||||
z = self.reparameterize(mu, logvar)
|
||||
logits = self.decoder(z, target_map, use_self_probility)
|
||||
return logits, mu, logvar
|
||||
return logits, mu, logvar
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = torch.device("cpu")
|
||||
|
||||
input = torch.randint(0, 32, [1, 13, 13]).to(device)
|
||||
|
||||
# 初始化模型
|
||||
model = GinkaVAE(device).to(device)
|
||||
|
||||
print_memory("初始化后")
|
||||
|
||||
# 前向传播
|
||||
start = time.perf_counter()
|
||||
logits, mu, logvar = model(input)
|
||||
end = time.perf_counter()
|
||||
|
||||
print_memory("前向传播后")
|
||||
|
||||
print(f"推理耗时: {end - start}")
|
||||
print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}")
|
||||
print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}")
|
||||
print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}")
|
||||
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user