ginka-generator/ginka/heatmap/cond.py

49 lines
1.8 KiB
Python

import torch
import torch.nn as nn
class HeatmapCond(nn.Module):
    """Encode a multi-channel heatmap into a feature map conditioned on a timestep.

    Three 3x3 convolutions with replicate padding (spatial size is preserved)
    widen the heatmap from ``heatmap_dim`` channels to ``output_dim // 4``,
    ``output_dim // 2`` and finally ``output_dim`` channels. After each
    convolution, a learned projection of the timestep embedding is added to
    every spatial position of the feature map.

    Args:
        T: Number of distinct timesteps. ``t`` passed to :meth:`forward` indexes
            an ``nn.Embedding`` of this size, so it must be an integer tensor
            with values in ``[0, T)``.
        embed_dim: Size of the timestep embedding.
        heatmap_dim: Number of input heatmap channels (``C`` in ``forward``).
        output_dim: Number of output channels. Intermediate widths are
            ``output_dim // 4`` and ``output_dim // 2``.
        dropout: Dropout probability inside each timestep projection. Defaults
            to 0.3, the value that used to be hard-coded.
    """

    def __init__(self, T=100, embed_dim=128, heatmap_dim=8, output_dim=128, dropout=0.3):
        super().__init__()
        self.time_embedding = nn.Embedding(T, embed_dim)
        # Attribute names, layer order inside each nn.Sequential, and the order
        # in which submodules are created are all kept identical to the
        # original layout so saved state_dicts still load and seeded
        # initialisation is unchanged.
        self.conv1 = self._conv_block(heatmap_dim, output_dim // 4, normalize=True)
        self.conv2 = self._conv_block(output_dim // 4, output_dim // 2, normalize=True)
        # The final convolution has no norm/activation; its output (plus the
        # timestep term) is what forward() returns.
        self.conv3 = self._conv_block(output_dim // 2, output_dim, normalize=False)
        self.fc1 = self._time_mlp(embed_dim, output_dim // 4, dropout)
        self.fc2 = self._time_mlp(embed_dim, output_dim // 2, dropout)
        self.fc3 = self._time_mlp(embed_dim, output_dim, dropout)

    @staticmethod
    def _conv_block(in_ch, out_ch, normalize):
        """Build a same-size 3x3 conv, optionally followed by BatchNorm2d + GELU."""
        layers = [nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate')]
        if normalize:
            layers += [nn.BatchNorm2d(out_ch), nn.GELU()]
        return nn.Sequential(*layers)

    @staticmethod
    def _time_mlp(embed_dim, out_dim, dropout):
        """Build the Linear -> Dropout -> LayerNorm -> GELU timestep projection."""
        return nn.Sequential(
            nn.Linear(embed_dim, out_dim),
            nn.Dropout(dropout),
            nn.LayerNorm(out_dim),
            nn.GELU(),
        )

    def forward(self, heatmap: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Encode ``heatmap`` conditioned on timesteps ``t``.

        Args:
            heatmap: Tensor of shape ``[B, heatmap_dim, H, W]``.
            t: Integer tensor of shape ``[B]`` with values in ``[0, T)``.

        Returns:
            Tensor of shape ``[B, output_dim, H, W]``.
        """
        t_embed = self.time_embedding(t)  # [B, embed_dim]
        x = heatmap
        for conv, fc in ((self.conv1, self.fc1), (self.conv2, self.fc2), (self.conv3, self.fc3)):
            # fc(t_embed) is [B, D]; appending two singleton axes gives
            # [B, D, 1, 1], which broadcasts over H and W. This is the same
            # tensor the former unsqueeze(1).unsqueeze(1).permute(0, 3, 1, 2)
            # chain produced.
            x = conv(x) + fc(t_embed)[:, :, None, None]
        return x