ginka-generator/ginka/transformer/vae.py
2026-03-10 23:06:23 +08:00

55 lines
1.9 KiB
Python

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from .encoder import GinkaTransformerVAEEncoder
from .decoder import GinkaTransformerVAEDecoder
from ..utils import print_memory
class GinkaTransformerVAE(nn.Module):
    """Transformer VAE pairing a Ginka encoder with a Ginka decoder.

    The encoder maps a flattened tile map to Gaussian posterior parameters
    (mu, logvar); a latent vector sampled via the reparameterization trick
    is then decoded back into per-cell class logits.
    """

    def __init__(self, num_classes=32, latent_dim=32):
        super().__init__()
        self.encoder = GinkaTransformerVAEEncoder(
            num_classes=num_classes, latent_dim=latent_dim
        )
        self.decoder = GinkaTransformerVAEDecoder(latent_dim=latent_dim)

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, exp(logvar)) with the reparameterization trick."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def _set_autoregressive(self, flag):
        # Toggles the inner decoder's generation mode in one place.
        self.decoder.decoder.autoregressive = flag

    def autoregressive(self):
        """Switch the inner decoder to autoregressive generation."""
        self._set_autoregressive(True)

    def teacher_forcing(self):
        """Switch the inner decoder to teacher-forced decoding."""
        self._set_autoregressive(False)

    def forward(self, target_map: torch.Tensor, use_self_probility=0):
        """Encode, sample a latent, and decode reconstruction logits.

        Args:
            target_map: [B, H * W] integer tile map.
            use_self_probility: unused in this method; kept for caller
                compatibility. NOTE(review): name looks like a typo of
                "probability" — confirm callers before renaming.

        Returns:
            (logits, mu, logvar) where logits is
            [B, H * W, num_classes] or [B, H * W] depending on the
            decoder's mode, and mu/logvar are [B, latent_dim].
        """
        mu, logvar = self.encoder(target_map)  # each: [B, latent_dim]
        latent = self.reparameterize(mu, logvar)
        logits = self.decoder(latent, target_map)
        return logits, mu, logvar
if __name__ == "__main__":
device = torch.device("cpu")
input = torch.randint(0, 32, [1, 169]).to(device)
# 初始化模型
model = GinkaTransformerVAE().to(device)
print_memory("初始化后")
# 前向传播
start = time.perf_counter()
logits, mu, logvar = model(input)
end = time.perf_counter()
print_memory("前向传播后")
print(f"推理耗时: {end - start}")
print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}")
print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}")
print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")