Mirror of https://github.com/unanmed/ginka-generator.git (synced 2026-05-14 04:41:12 +08:00)
feat: add transformer-based fusion to the encoder; improve the decoder
commit b7fe24ee4c (parent 8caa37a144)
@@ -124,6 +124,7 @@ def train():
         loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05)

         loss.backward()
+        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=2.0)
         optimizer_ginka.step()
         loss_total += loss.detach()
         reco_loss_total += reco_loss.detach()
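For context on the added line: the global-norm clip has to sit between `loss.backward()` and `optimizer_ginka.step()` so it rescales the freshly computed gradients before they are applied. Below is a minimal sketch of the pattern; `criterion.vae_loss` is not part of this diff, so the cross-entropy-plus-KL body here is an assumption about a typical VAE objective, not the repo's actual code.

```python
import torch
import torch.nn.functional as F

def vae_loss(fake_logits, target_map, mu, logvar, kl_weight=0.05):
    # Hypothetical stand-in for criterion.vae_loss (not shown in this diff):
    # cross-entropy reconstruction over tile classes plus a KL term to N(0, I).
    reco = F.cross_entropy(fake_logits.flatten(0, -2), target_map.flatten())
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return reco + kl_weight * kl, reco, kl

fake_logits = torch.randn(2, 169, 32, requires_grad=True)  # [B, H*W, tile_classes]
target_map = torch.randint(0, 32, (2, 13, 13))             # [B, H, W]
mu, logvar = torch.randn(2, 32), torch.randn(2, 32)

loss, reco, kl = vae_loss(fake_logits, target_map, mu, logvar)
loss.backward()
# clip_grad_norm_ would go here, before optimizer.step(), exactly as in the hunk
```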
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from ..utils import print_memory

-class GinkaMapPatch(nn.Module):
+class DecoderMapPatch(nn.Module):
     def __init__(self, tile_classes=32, width=13, height=13):
         super().__init__()

@@ -64,18 +64,16 @@ class GinkaMapPatch(nn.Module):
         feat = self.fc(feat)
         return feat

-class GinkaTileEmbedding(nn.Module):
+class DecoderTileEmbedding(nn.Module):
     def __init__(self, tile_classes=32, embed_dim=256):
         super().__init__()

         # Tile embedding for the tile drawn at the previous step
         self.embedding = nn.Embedding(tile_classes, embed_dim)

     def forward(self, tile: torch.Tensor):
         return self.embedding(tile)

-class GinkaPosEmbedding(nn.Module):
+class DecoderPosEmbedding(nn.Module):
     def __init__(self, width=13, height=13, embed_dim=256):
         super().__init__()

@@ -86,17 +84,13 @@ class GinkaPosEmbedding(nn.Module):

         self.row_embedding = nn.Embedding(height, embed_dim)
         self.col_embedding = nn.Embedding(width, embed_dim)
-        self.fusion = nn.Linear(embed_dim * 2, embed_dim)

     def forward(self, x: torch.Tensor, y: torch.Tensor):
         row = self.row_embedding(y)
         col = self.col_embedding(x)
-        embed = torch.cat([row, col], dim=2)
-        fused = self.fusion(embed)
-
-        return fused
+        return row, col

-class GinkaInputFusion(nn.Module):
+class DecoderInputFusion(nn.Module):
     def __init__(self, d_model=256):
         super().__init__()

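The positional module now returns the row and column embeddings separately instead of fusing them through a linear layer, leaving it to `DecoderInputFusion` to treat them as two distinct tokens. A minimal sketch of the factored lookup follows; the row-major scan order is an assumption about what `col_list`/`row_list` hold in `VAEDecoder`.

```python
import torch
import torch.nn as nn

height, width, embed_dim = 13, 13, 256
row_embedding = nn.Embedding(height, embed_dim)
col_embedding = nn.Embedding(width, embed_dim)

# Enumerate cells in row-major scan order (an assumption about the
# contents of col_list/row_list).
ys = torch.arange(height).repeat_interleave(width)  # 0,0,...,0,1,1,...
xs = torch.arange(width).repeat(height)             # 0,1,...,12,0,1,...

row = row_embedding(ys)  # [169, 256], one vector per cell's row index
col = col_embedding(xs)  # [169, 256], one vector per cell's column index
```

Returning the two halves separately lets the fusion transformer attend to row and column position as independent tokens rather than one pre-mixed vector.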
@@ -107,24 +101,31 @@ class GinkaInputFusion(nn.Module):
                 d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True,
                 dropout=0.2
             ),
-            num_layers=3
+            num_layers=2
         )
+        self.norm = nn.LayerNorm(d_model)
+        self.fusion = nn.Linear(d_model * 2, d_model)

     def forward(
         self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
-        pos_embed: torch.Tensor, patch_vec: torch.Tensor
+        col_embed: torch.Tensor, row_embed: torch.Tensor, patch_vec: torch.Tensor
     ):
         """
         tile_embed: [B, 256]
         cond_vec: [B, 256]
-        pos_embed: [B, 256]
+        col_embed: [B, 256]
+        row_embed: [B, 256]
         patch_vec: [B, 256]
         """
-        vec = torch.stack([tile_embed, cond_vec, pos_embed, patch_vec], dim=1)
-        feat = self.transformer(vec)
-        return feat[:, 0]
+        vec = torch.stack([tile_embed, cond_vec, col_embed, row_embed, patch_vec], dim=1)
+        feat = self.norm(self.transformer(vec))
+        mean = torch.mean(feat, dim=1)
+        max = torch.max(feat, dim=1).values
+        hidden = torch.cat([mean, max], dim=1)
+        fused = self.fusion(hidden)
+        return fused

-class GinkaRNN(nn.Module):
+class DecoderRNN(nn.Module):
     def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512):
         super().__init__()

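The fusion head changed from reading out a single token (`feat[:, 0]`) to pooling all five input tokens with mean and max, then projecting the concatenated summary back to `d_model`. A standalone sketch of that readout:

```python
import torch
import torch.nn as nn

d_model = 256
norm = nn.LayerNorm(d_model)
fusion = nn.Linear(d_model * 2, d_model)

feat = torch.randn(4, 5, d_model)  # [B, tokens, d_model]; 5 tokens as in the hunk
feat = norm(feat)                  # LayerNorm applied after the transformer
mean = feat.mean(dim=1)            # [B, d_model] smooth summary of all tokens
mx = feat.max(dim=1).values        # [B, d_model] strongest activation per channel
fused = fusion(torch.cat([mean, mx], dim=1))  # back to [B, d_model]
```

Mean pooling keeps an average view of every token while max pooling preserves the strongest per-channel activation, which is presumably why both summaries are concatenated before the projection.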
@@ -162,13 +163,16 @@ class VAEDecoder(nn.Module):

         # Model structure
         self.map_vec_fc = nn.Sequential(
-            nn.Linear(map_vec_dim, 256)
+            nn.Linear(map_vec_dim, 128),
+            nn.LayerNorm(128),
+            nn.GELU(),
+            nn.Linear(128, 256)
         )
-        self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes)
-        self.pos_embedding = GinkaPosEmbedding()
-        self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes)
-        self.feat_fusion = GinkaInputFusion()
-        self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
+        self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes)
+        self.pos_embedding = DecoderPosEmbedding()
+        self.map_patch = DecoderMapPatch(tile_classes=self.tile_classes)
+        self.feat_fusion = DecoderInputFusion()
+        self.rnn = DecoderRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)

         self.col_list = []
         self.row_list = []
@@ -181,7 +185,7 @@ class VAEDecoder(nn.Module):
         """
         map_vec: [B, vec_dim]
         target_map: [B, H, W]
-        use_self: whether to feed the model's own previous output into the next step
+        use_self_probility: probability of feeding the model's own previous output into the next step
         """
         B, C = map_vec.shape

@@ -194,7 +198,7 @@ class VAEDecoder(nn.Module):

         col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
         row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
-        pos_embed = self.pos_embedding(col_list, row_list)
+        col_embed, row_embed = self.pos_embedding(col_list, row_list)

         map_vec = self.map_vec_fc(map_vec)

@@ -206,7 +210,7 @@ class VAEDecoder(nn.Module):
             use_self = random.random() < use_self_probility
             map_patch = self.map_patch(map if use_self else target_map, x, y)
             # Fuse the encoded features
-            feat = self.feat_fusion(tile_embed, map_vec, pos_embed[:, idx], map_patch)
+            feat = self.feat_fusion(tile_embed, map_vec, col_embed[:, idx], row_embed[:, idx], map_patch)
             # RNN output
             logits, h = self.rnn(feat, hidden)
             # Process the output
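The decoder's `use_self_probility` implements a form of scheduled sampling: at each cell it flips a biased coin to decide whether the map patch is read from the model's own partial output or from the ground-truth `target_map` (teacher forcing). A minimal sketch of the pattern, with a hypothetical `decode_step` standing in for the patch/RNN machinery:

```python
import random
import torch

def decode_step(prev, idx):
    # Hypothetical per-cell decode; stands in for the real patch/RNN step.
    return torch.randint(0, 32, prev[:, idx].shape)

def rollout(target_map, use_self_probility=0.5):
    generated = torch.zeros_like(target_map)
    for idx in range(target_map.shape[1]):
        use_self = random.random() < use_self_probility
        # Condition on own output with probability p, otherwise teacher-force.
        source = generated if use_self else target_map
        generated[:, idx] = decode_step(source, idx)
    return generated

maps = torch.randint(0, 32, (2, 169))  # [B, H*W] flattened maps
out = rollout(maps, use_self_probility=0.3)
```

Mixing the two sources during training narrows the gap between teacher-forced training and fully autoregressive inference.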
@@ -7,9 +7,21 @@ from ..utils import print_memory
 class EncoderEmbedding(nn.Module):
     def __init__(self, tile_classes=32, width=13, height=13, hidden_dim=128, output_dim=256):
         super().__init__()
-        self.tile_embedding = nn.Embedding(tile_classes, hidden_dim)
-        self.col_embedding = nn.Embedding(width, hidden_dim)
-        self.row_embedding = nn.Embedding(height, hidden_dim)
+        self.tile_embedding = nn.Sequential(
+            nn.Embedding(tile_classes, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+            nn.GELU()
+        )
+        self.col_embedding = nn.Sequential(
+            nn.Embedding(width, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+            nn.GELU()
+        )
+        self.row_embedding = nn.Sequential(
+            nn.Embedding(height, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+            nn.GELU()
+        )
         self.fusion = nn.Linear(hidden_dim * 3, output_dim)

     def forward(self, tile, x, y):
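Wrapping `nn.Embedding` in an `nn.Sequential` with `LayerNorm` and `GELU` works because the embedding accepts the integer indices and the following modules operate on its float output. A quick shape check under the diff's defaults:

```python
import torch
import torch.nn as nn

tile_classes, hidden_dim = 32, 128
tile_embedding = nn.Sequential(
    nn.Embedding(tile_classes, hidden_dim),  # LongTensor indices -> float vectors
    nn.LayerNorm(hidden_dim),                # normalize each embedding vector
    nn.GELU(),
)

tiles = torch.randint(0, tile_classes, (2, 169))  # [B, H*W] tile ids
out = tile_embedding(tiles)
print(out.shape)  # torch.Size([2, 169, 128])
```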
@@ -43,6 +55,25 @@ class EncoderGRU(nn.Module):
         hidden = self.drop(self.gru(feat, hidden))
         logits = self.fc(hidden)
         return logits, hidden

+class EncoderFusion(nn.Module):
+    def __init__(self, d_model=256):
+        super().__init__()
+
+        self.transformer = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(
+                d_model=d_model, dim_feedforward=d_model, nhead=2, batch_first=True
+            ),
+            num_layers=2
+        )
+        self.norm = nn.LayerNorm(d_model)
+
+    def forward(self, logits):
+        x = self.norm(self.transformer(logits))
+        h_mean = torch.mean(x, dim=1)
+        h_max = torch.max(x, dim=1).values
+        h = torch.cat([h_mean, h_max], dim=1)
+        return h
+
 class VAEEncoder(nn.Module):
     def __init__(self, device, tile_classes=32, latent_dim=32, width=13, height=13):
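A note on the `EncoderFusion` construction above: with `batch_first=True` the module expects `[B, seq_len, d_model]`, and `d_model` must be divisible by `nhead` (here 256 / 2 = 128 dims per head). A standalone check of the same configuration:

```python
import torch
import torch.nn as nn

encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=256, dim_feedforward=256, nhead=2, batch_first=True
    ),
    num_layers=2,
)

x = torch.randn(2, 169, 256)  # batch_first=True expects [B, seq_len, d_model]
y = encoder(x)                # shape is preserved: [2, 169, 256]
```

Setting `dim_feedforward=d_model` keeps the feed-forward sublayer much slimmer than the conventional 4x `d_model` width, trading capacity for a smaller parameter count.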
@@ -54,6 +85,7 @@ class VAEEncoder(nn.Module):

         self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
         self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
+        self.fusion = EncoderFusion(256)
         self.fc_mu = nn.Linear(512, latent_dim)
         self.fc_logvar = nn.Linear(512, latent_dim)

@@ -79,9 +111,8 @@ class VAEEncoder(nn.Module):
             logits, h = self.rnn(embed[:, idx], hidden)
             hidden = h
             output[:, idx] = logits
-        h_mean = torch.mean(output, dim=1)
-        h_max = torch.max(output, dim=1).values
-        h = torch.cat([h_mean, h_max], dim=1)
+        h = self.fusion(output)
+
         mu = self.fc_mu(h)
         logvar = self.fc_logvar(h)
         return mu, logvar
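The `fc_mu`/`fc_logvar` heads feed the reparameterization step in `GinkaVAE.reparameterize`, which this diff does not show. The sketch below is the textbook formulation z = mu + eps * exp(0.5 * logvar), not necessarily this repo's exact code:

```python
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Standard VAE reparameterization: sample eps ~ N(0, I), then shift and
    # scale it so the sampling step stays differentiable w.r.t. mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

mu, logvar = torch.zeros(2, 32), torch.zeros(2, 32)
z = reparameterize(mu, logvar)  # [2, 32] latent samples
```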
@@ -107,4 +138,5 @@ if __name__ == "__main__":
     print(f"Output shapes: mu={mu.shape}, logvar={logvar.shape}")
     print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}")
     print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
+    print(f"Fusion parameters: {sum(p.numel() for p in model.fusion.parameters())}")
     print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
@@ -1,8 +1,10 @@
+import time
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from .encoder import VAEEncoder
 from .decoder import VAEDecoder
+from ..utils import print_memory

 class GinkaVAE(nn.Module):
     def __init__(self, device, tile_classes=32, latent_dim=32):
@@ -19,4 +21,27 @@ class GinkaVAE(nn.Module):
         mu, logvar = self.encoder(target_map)
         z = self.reparameterize(mu, logvar)
         logits = self.decoder(z, target_map, use_self_probility)
         return logits, mu, logvar
+
+if __name__ == "__main__":
+    device = torch.device("cpu")
+
+    input = torch.randint(0, 32, [1, 13, 13]).to(device)
+
+    # Initialize the model
+    model = GinkaVAE(device).to(device)
+
+    print_memory("after init")
+
+    # Forward pass
+    start = time.perf_counter()
+    logits, mu, logvar = model(input)
+    end = time.perf_counter()
+
+    print_memory("after forward")
+
+    print(f"Inference time: {end - start}")
+    print(f"Output shapes: logits={logits.shape}, mu={mu.shape}, logvar={logvar.shape}")
+    print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}")
+    print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}")
+    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
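Given `latent_dim=32` and a batch of one, the smoke test above should report `mu` and `logvar` of shape `[1, 32]`; the layout of `logits` is decoder-defined. Assuming the decoder keeps the 32 tile classes in the trailing dimension, hypothetical assertions one might append:

```python
# Hypothetical checks appended after the forward pass above.
assert mu.shape == (1, 32) and logvar.shape == (1, 32)
assert logits.shape[-1] == 32  # per-cell distribution over tile classes
```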