import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from .encoder import GinkaTransformerVAEEncoder
from .decoder import GinkaTransformerVAEDecoder
from ..utils import print_memory


class GinkaTransformerVAE(nn.Module):
    """Variational autoencoder built from a transformer encoder and decoder.

    The encoder turns a map into the parameters (``mu``, ``logvar``) of a
    latent distribution; a latent vector is sampled from it and handed to the
    decoder together with the original map.
    """

    def __init__(self, num_classes=32, latent_dim=32):
        """Build the encoder/decoder pair.

        Args:
            num_classes: Forwarded to the encoder as ``num_classes``.
            latent_dim: Forwarded to both the encoder and the decoder as
                ``latent_dim``.
        """
        super().__init__()
        self.encoder = GinkaTransformerVAEEncoder(
            num_classes=num_classes, latent_dim=latent_dim
        )
        # NOTE(review): `num_classes` is not forwarded to the decoder, so a
        # non-default value only reaches the encoder. Confirm against the
        # decoder's constructor whether that is intended.
        self.decoder = GinkaTransformerVAEDecoder(latent_dim=latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        Args:
            mu: Mean tensor.
            logvar: Log-variance tensor; ``std`` is ``exp(0.5 * logvar)``.

        Returns:
            A tensor with the broadcast shape of ``mu`` and ``logvar``. The
            noise is created with ``torch.randn_like`` and therefore does not
            depend on the inputs, which keeps the sample differentiable with
            respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def autoregressive(self):
        """Set the ``autoregressive`` flag on the inner decoder to ``True``."""
        self.decoder.decoder.autoregressive = True

    def teacher_forcing(self):
        """Set the ``autoregressive`` flag on the inner decoder to ``False``."""
        self.decoder.decoder.autoregressive = False

    def forward(self, target_map: torch.Tensor, use_self_probility=0):
        """Encode ``target_map``, sample a latent, and decode it.

        Args:
            target_map: Input map; annotated in the original code as
                ``[B, H * W]``.
            use_self_probility: Currently unused by this method. It is kept
                (including its spelling) so existing callers keep working.

        Returns:
            A tuple ``(logits, mu, logvar)``. The original annotations give
            ``mu`` / ``logvar`` as ``[B, latent_dim]`` and the decoder output
            as ``[B, H * W, num_classes]`` or ``[B, H * W]`` — presumably
            depending on the decoder's mode; confirm against the decoder.
        """
        mu, logvar = self.encoder(target_map)
        z = self.reparameterize(mu, logvar)
        # The decoder receives the target map as well as the sampled latent.
        logits = self.decoder(z, target_map)
        return logits, mu, logvar


if __name__ == "__main__":
    # Smoke test. Because of the relative imports above, this file has to be
    # executed as a module of its package (``python -m ...``), not as a path.
    device = torch.device("cpu")
    # Named `sample_map` rather than `input` so the builtin is not shadowed.
    # Class ids are drawn from [0, 32), matching the default `num_classes`.
    sample_map = torch.randint(0, 32, [1, 169]).to(device)

    # Initialize the model.
    model = GinkaTransformerVAE().to(device)

    print_memory("初始化后")

    # Forward pass, timed with a monotonic high-resolution clock.
    start = time.perf_counter()
    logits, mu, logvar = model(sample_map)
    end = time.perf_counter()

    print_memory("前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}")
    print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}")
    print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")