feat: vae编码器换为 rnn

This commit is contained in:
unanmed 2026-01-20 16:31:13 +08:00
parent 169a514dd1
commit dd6a043487
2 changed files with 84 additions and 29 deletions

View File

@ -84,12 +84,12 @@ class GinkaPosEmbedding(nn.Module):
self.width = width
self.height = height
self.row_embedding = nn.Embedding(width, embed_dim)
self.col_embedding = nn.Embedding(height, embed_dim)
self.row_embedding = nn.Embedding(height, embed_dim)
self.col_embedding = nn.Embedding(width, embed_dim)
def forward(self, x: torch.Tensor, y: torch.Tensor):
row = self.row_embedding(x).squeeze(1)
col = self.col_embedding(y).squeeze(1)
row = self.row_embedding(y).squeeze(1)
col = self.col_embedding(x).squeeze(1)
return row, col
@ -104,7 +104,7 @@ class GinkaInputFusion(nn.Module):
d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True,
dropout=0.2
),
num_layers=4
num_layers=3
)
def forward(
@ -220,13 +220,13 @@ if __name__ == "__main__":
# 前向传播
start = time.perf_counter()
fake_logits, fake_map = model(map_vec, input, 0)
fake_logits = model(map_vec, input, 0)
end = time.perf_counter()
print_memory("前向传播后")
print(f"推理耗时: {end - start}")
print(f"输出形状: fake_logits={fake_logits.shape}, fake_map={fake_map.shape}")
print(f"输出形状: fake_logits={fake_logits.shape}")
print(f"Map Vector FC parameters: {sum(p.numel() for p in model.map_vec_fc.parameters())}")
print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}")

View File

@ -4,29 +4,84 @@ import torch.nn as nn
import torch.nn.functional as F
from ..utils import print_memory
class VAEEncoder(nn.Module):
def __init__(self, tile_classes=32, latent_dim=32):
class EncoderEmbedding(nn.Module):
def __init__(self, tile_classes=32, width=13, height=13, hidden_dim=128, output_dim=256):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(tile_classes, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
self.tile_embedding = nn.Embedding(tile_classes, hidden_dim)
self.col_embedding = nn.Embedding(width, hidden_dim)
self.row_embedding = nn.Embedding(height, hidden_dim)
self.fusion = nn.Linear(hidden_dim * 3, output_dim)
def forward(self, tile, x, y):
tile_embed = self.tile_embedding(tile)
col_embed = self.col_embedding(x)
row_embed = self.row_embedding(y)
embed = torch.cat([tile_embed, col_embed, row_embed], dim=2)
fused = self.fusion(embed)
return fused
class EncoderGRU(nn.Module):
def __init__(self, input_dim=256, hidden_dim=512, output_dim=256):
super().__init__()
# GRU
self.gru = nn.GRUCell(input_dim, hidden_dim)
self.drop = nn.Dropout(0.1)
self.fc = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Conv2d(64, 128, 3, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128, 256, 3, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Flatten()
nn.Linear(hidden_dim, output_dim)
)
self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)
def forward(self, feat: torch.Tensor, hidden: torch.Tensor):
"""
feat: [B, input_dim]
hidden: [B, hidden_dim]
"""
hidden = self.drop(self.gru(feat, hidden))
logits = self.fc(hidden)
return logits, hidden
class VAEEncoder(nn.Module):
def __init__(self, device, tile_classes=32, latent_dim=32, width=13, height=13):
super().__init__()
self.device = device
self.rnn_hidden = 512
self.logits_dim = 256
self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
self.fc_mu = nn.Linear(512, latent_dim)
self.fc_logvar = nn.Linear(512, latent_dim)
self.col_list = []
self.row_list = []
for y in range(0, height):
for x in range(0, width):
self.col_list.append(x)
self.row_list.append(y)
def forward(self, x):
h = self.conv(x)
def forward(self, x: torch.Tensor):
B, H, W = x.shape
map = torch.flatten(x, start_dim=1)
hidden = torch.zeros(B, self.rnn_hidden).to(self.device)
output = torch.zeros(B, H * W, self.logits_dim).to(self.device)
col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
embed = self.embedding(map, col_list, row_list)
for idx in range(0, len(self.col_list)):
logits, h = self.rnn(embed[:, idx], hidden)
hidden = h
output[:, idx] = logits
h_mean = torch.mean(output, dim=1)
h_max = torch.max(output, dim=1).values
h = torch.cat([h_mean, h_max], dim=1)
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar
@ -35,10 +90,9 @@ if __name__ == "__main__":
device = torch.device("cpu")
input = torch.randint(0, 32, [1, 13, 13]).to(device)
input = F.one_hot(input, 32).permute(0, 3, 1, 2).float()
# 初始化模型
model = VAEEncoder().to(device)
model = VAEEncoder(device).to(device)
print_memory("初始化后")
@ -51,5 +105,6 @@ if __name__ == "__main__":
print(f"推理耗时: {end - start}")
print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
print(f"CNN parameters: {sum(p.numel() for p in model.conv.parameters())}")
print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}")
print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")