feat: 编码器加入 transformer 整合,改进解码器

This commit is contained in:
unanmed 2026-01-20 23:48:25 +08:00
parent 8caa37a144
commit b7fe24ee4c
4 changed files with 96 additions and 34 deletions

View File

@ -124,6 +124,7 @@ def train():
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05) loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05)
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=2.0)
optimizer_ginka.step() optimizer_ginka.step()
loss_total += loss.detach() loss_total += loss.detach()
reco_loss_total += reco_loss.detach() reco_loss_total += reco_loss.detach()

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..utils import print_memory from ..utils import print_memory
class GinkaMapPatch(nn.Module): class DecoderMapPatch(nn.Module):
def __init__(self, tile_classes=32, width=13, height=13): def __init__(self, tile_classes=32, width=13, height=13):
super().__init__() super().__init__()
@ -64,18 +64,16 @@ class GinkaMapPatch(nn.Module):
feat = self.fc(feat) feat = self.fc(feat)
return feat return feat
class GinkaTileEmbedding(nn.Module): class DecoderTileEmbedding(nn.Module):
def __init__(self, tile_classes=32, embed_dim=256): def __init__(self, tile_classes=32, embed_dim=256):
super().__init__() super().__init__()
# 图块编码,上一次画的图块 # 图块编码,上一次画的图块
self.embedding = nn.Embedding(tile_classes, embed_dim) self.embedding = nn.Embedding(tile_classes, embed_dim)
def forward(self, tile: torch.Tensor): def forward(self, tile: torch.Tensor):
return self.embedding(tile) return self.embedding(tile)
class GinkaPosEmbedding(nn.Module): class DecoderPosEmbedding(nn.Module):
def __init__(self, width=13, height=13, embed_dim=256): def __init__(self, width=13, height=13, embed_dim=256):
super().__init__() super().__init__()
@ -86,17 +84,13 @@ class GinkaPosEmbedding(nn.Module):
self.row_embedding = nn.Embedding(height, embed_dim) self.row_embedding = nn.Embedding(height, embed_dim)
self.col_embedding = nn.Embedding(width, embed_dim) self.col_embedding = nn.Embedding(width, embed_dim)
self.fusion = nn.Linear(embed_dim * 2, embed_dim)
def forward(self, x: torch.Tensor, y: torch.Tensor): def forward(self, x: torch.Tensor, y: torch.Tensor):
row = self.row_embedding(y) row = self.row_embedding(y)
col = self.col_embedding(x) col = self.col_embedding(x)
embed = torch.cat([row, col], dim=2) return row, col
fused = self.fusion(embed)
return fused
class GinkaInputFusion(nn.Module): class DecoderInputFusion(nn.Module):
def __init__(self, d_model=256): def __init__(self, d_model=256):
super().__init__() super().__init__()
@ -107,24 +101,31 @@ class GinkaInputFusion(nn.Module):
d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True, d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True,
dropout=0.2 dropout=0.2
), ),
num_layers=3 num_layers=2
) )
self.norm = nn.LayerNorm(d_model)
self.fusion = nn.Linear(d_model * 2, d_model)
def forward( def forward(
self, tile_embed: torch.Tensor, cond_vec: torch.Tensor, self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
pos_embed: torch.Tensor, patch_vec: torch.Tensor col_embed: torch.Tensor, row_embed: torch.Tensor, patch_vec: torch.Tensor
): ):
""" """
tile_embed: [B, 256] tile_embed: [B, 256]
cond_vec: [B, 256] cond_vec: [B, 256]
pos_embed: [B, 256] col_embed: [B, 256]
row_embed: [B, 256]
patch_vec: [B, 256] patch_vec: [B, 256]
""" """
vec = torch.stack([tile_embed, cond_vec, pos_embed, patch_vec], dim=1) vec = torch.stack([tile_embed, cond_vec, col_embed, row_embed, patch_vec], dim=1)
feat = self.transformer(vec) feat = self.norm(self.transformer(vec))
return feat[:, 0] mean = torch.mean(feat, dim=1)
max = torch.max(feat, dim=1).values
hidden = torch.cat([mean, max], dim=1)
fused = self.fusion(hidden)
return fused
class GinkaRNN(nn.Module): class DecoderRNN(nn.Module):
def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512): def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512):
super().__init__() super().__init__()
@ -162,13 +163,16 @@ class VAEDecoder(nn.Module):
# 模型结构 # 模型结构
self.map_vec_fc = nn.Sequential( self.map_vec_fc = nn.Sequential(
nn.Linear(map_vec_dim, 256) nn.Linear(map_vec_dim, 128),
nn.LayerNorm(128),
nn.GELU(),
nn.Linear(128, 256)
) )
self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes) self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes)
self.pos_embedding = GinkaPosEmbedding() self.pos_embedding = DecoderPosEmbedding()
self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes) self.map_patch = DecoderMapPatch(tile_classes=self.tile_classes)
self.feat_fusion = GinkaInputFusion() self.feat_fusion = DecoderInputFusion()
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden) self.rnn = DecoderRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
self.col_list = [] self.col_list = []
self.row_list = [] self.row_list = []
@ -181,7 +185,7 @@ class VAEDecoder(nn.Module):
""" """
map_vec: [B, vec_dim] map_vec: [B, vec_dim]
target_map: [B, H, W] target_map: [B, H, W]
use_self: 是否使用自己生成的上一步结果执行下一步 use_self_probility: 使用自己生成的上一步结果执行下一步的概率
""" """
B, C = map_vec.shape B, C = map_vec.shape
@ -194,7 +198,7 @@ class VAEDecoder(nn.Module):
col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1) col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1) row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
pos_embed = self.pos_embedding(col_list, row_list) col_embed, row_embed = self.pos_embedding(col_list, row_list)
map_vec = self.map_vec_fc(map_vec) map_vec = self.map_vec_fc(map_vec)
@ -206,7 +210,7 @@ class VAEDecoder(nn.Module):
use_self = random.random() < use_self_probility use_self = random.random() < use_self_probility
map_patch = self.map_patch(map if use_self else target_map, x, y) map_patch = self.map_patch(map if use_self else target_map, x, y)
# 编码特征融合 # 编码特征融合
feat = self.feat_fusion(tile_embed, map_vec, pos_embed[:, idx], map_patch) feat = self.feat_fusion(tile_embed, map_vec, col_embed[:, idx], row_embed[:, idx], map_patch)
# RNN 输出 # RNN 输出
logits, h = self.rnn(feat, hidden) logits, h = self.rnn(feat, hidden)
# 处理输出 # 处理输出

View File

@ -7,9 +7,21 @@ from ..utils import print_memory
class EncoderEmbedding(nn.Module): class EncoderEmbedding(nn.Module):
def __init__(self, tile_classes=32, width=13, height=13, hidden_dim=128, output_dim=256): def __init__(self, tile_classes=32, width=13, height=13, hidden_dim=128, output_dim=256):
super().__init__() super().__init__()
self.tile_embedding = nn.Embedding(tile_classes, hidden_dim) self.tile_embedding = nn.Sequential(
self.col_embedding = nn.Embedding(width, hidden_dim) nn.Embedding(tile_classes, hidden_dim),
self.row_embedding = nn.Embedding(height, hidden_dim) nn.LayerNorm(hidden_dim),
nn.GELU()
)
self.col_embedding = nn.Sequential(
nn.Embedding(width, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU()
)
self.row_embedding = nn.Sequential(
nn.Embedding(height, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU()
)
self.fusion = nn.Linear(hidden_dim * 3, output_dim) self.fusion = nn.Linear(hidden_dim * 3, output_dim)
def forward(self, tile, x, y): def forward(self, tile, x, y):
@ -43,6 +55,25 @@ class EncoderGRU(nn.Module):
hidden = self.drop(self.gru(feat, hidden)) hidden = self.drop(self.gru(feat, hidden))
logits = self.fc(hidden) logits = self.fc(hidden)
return logits, hidden return logits, hidden
class EncoderFusion(nn.Module):
def __init__(self, d_model=256):
super().__init__()
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=d_model, dim_feedforward=d_model, nhead=2, batch_first=True
),
num_layers=2
)
self.norm = nn.LayerNorm(d_model)
def forward(self, logits):
x = self.norm(self.transformer(logits))
h_mean = torch.mean(x, dim=1)
h_max = torch.max(x, dim=1).values
h = torch.cat([h_mean, h_max], dim=1)
return h
class VAEEncoder(nn.Module): class VAEEncoder(nn.Module):
def __init__(self, device, tile_classes=32, latent_dim=32, width=13, height=13): def __init__(self, device, tile_classes=32, latent_dim=32, width=13, height=13):
@ -54,6 +85,7 @@ class VAEEncoder(nn.Module):
self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256) self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim) self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
self.fusion = EncoderFusion(256)
self.fc_mu = nn.Linear(512, latent_dim) self.fc_mu = nn.Linear(512, latent_dim)
self.fc_logvar = nn.Linear(512, latent_dim) self.fc_logvar = nn.Linear(512, latent_dim)
@ -79,9 +111,8 @@ class VAEEncoder(nn.Module):
logits, h = self.rnn(embed[:, idx], hidden) logits, h = self.rnn(embed[:, idx], hidden)
hidden = h hidden = h
output[:, idx] = logits output[:, idx] = logits
h_mean = torch.mean(output, dim=1)
h_max = torch.max(output, dim=1).values h = self.fusion(output)
h = torch.cat([h_mean, h_max], dim=1)
mu = self.fc_mu(h) mu = self.fc_mu(h)
logvar = self.fc_logvar(h) logvar = self.fc_logvar(h)
return mu, logvar return mu, logvar
@ -107,4 +138,5 @@ if __name__ == "__main__":
print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}") print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}") print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}")
print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}") print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
print(f"Fusion parameters: {sum(p.numel() for p in model.fusion.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,8 +1,10 @@
import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .encoder import VAEEncoder from .encoder import VAEEncoder
from .decoder import VAEDecoder from .decoder import VAEDecoder
from ..utils import print_memory
class GinkaVAE(nn.Module): class GinkaVAE(nn.Module):
def __init__(self, device, tile_classes=32, latent_dim=32): def __init__(self, device, tile_classes=32, latent_dim=32):
@ -19,4 +21,27 @@ class GinkaVAE(nn.Module):
mu, logvar = self.encoder(target_map) mu, logvar = self.encoder(target_map)
z = self.reparameterize(mu, logvar) z = self.reparameterize(mu, logvar)
logits = self.decoder(z, target_map, use_self_probility) logits = self.decoder(z, target_map, use_self_probility)
return logits, mu, logvar return logits, mu, logvar
if __name__ == "__main__":
device = torch.device("cpu")
input = torch.randint(0, 32, [1, 13, 13]).to(device)
# 初始化模型
model = GinkaVAE(device).to(device)
print_memory("初始化后")
# 前向传播
start = time.perf_counter()
logits, mu, logvar = model(input)
end = time.perf_counter()
print_memory("前向传播后")
print(f"推理耗时: {end - start}")
print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}")
print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}")
print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")