Mirror of https://github.com/unanmed/ginka-generator.git
Synced 2026-05-14 04:41:12 +08:00
49 lines · 1.8 KiB · Python
import torch
|
|
import torch.nn as nn
|
|
|
|
class HeatmapCond(nn.Module):
    """Encode a multi-channel heatmap, conditioned on a discrete timestep index.

    The heatmap passes through three 3x3 convolution stages. After each stage,
    a learned projection of the timestep embedding is added as a per-channel
    bias (broadcast over H and W).
    NOTE(review): ``t`` is presumably a diffusion step index; confirm against
    the caller.
    """

    def __init__(self, T=100, embed_dim=128, heatmap_dim=8, output_dim=128):
        """Build the embedding table, conv stages and timestep projections.

        Args:
            T: size of the timestep embedding table; ``t`` values passed to
                ``forward`` must lie in ``[0, T)``.
            embed_dim: width of the learned timestep embedding.
            heatmap_dim: number of input channels of the heatmap.
            output_dim: number of output channels. The intermediate stages use
                ``output_dim // 4`` and ``output_dim // 2`` channels.
        """
        super().__init__()
        self.time_embedding = nn.Embedding(T, embed_dim)

        # kernel 3, padding 1, stride 1: spatial size (H, W) is preserved by
        # every stage while channels widen heatmap_dim -> C/4 -> C/2 -> C.
        self.conv1 = nn.Sequential(
            nn.Conv2d(heatmap_dim, output_dim // 4, 3, padding=1, padding_mode='replicate'),
            nn.BatchNorm2d(output_dim // 4),
            nn.GELU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(output_dim // 4, output_dim // 2, 3, padding=1, padding_mode='replicate'),
            nn.BatchNorm2d(output_dim // 2),
            nn.GELU()
        )
        # The final stage has no norm/activation, so the time bias is added
        # to the raw conv output.
        self.conv3 = nn.Sequential(
            nn.Conv2d(output_dim // 2, output_dim, 3, padding=1, padding_mode='replicate')
        )

        # One projection per conv stage, sized to that stage's channel count.
        # The submodule order is kept as-is so existing state_dict keys
        # (e.g. ``fc1.2.weight``) still load.
        self.fc1 = nn.Sequential(
            nn.Linear(embed_dim, output_dim // 4),
            nn.Dropout(0.3),
            nn.LayerNorm(output_dim // 4),
            nn.GELU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(embed_dim, output_dim // 2),
            nn.Dropout(0.3),
            nn.LayerNorm(output_dim // 2),
            nn.GELU()
        )
        self.fc3 = nn.Sequential(
            nn.Linear(embed_dim, output_dim),
            nn.Dropout(0.3),
            nn.LayerNorm(output_dim),
            nn.GELU()
        )

    def forward(self, heatmap: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Encode ``heatmap`` with the timestep injected after each conv stage.

        Args:
            heatmap: tensor of shape ``[B, heatmap_dim, H, W]``.
            t: integer index tensor of shape ``[B]`` with values in
                ``[0, T)``. A 0-dim (scalar) tensor is also accepted and is
                shared by every sample in the batch.

        Returns:
            Tensor of shape ``[B, output_dim, H, W]``.
        """
        if t.dim() == 0:
            # A single shared step: give every sample the same index so the
            # embedding comes out as [B, embed_dim] like the batched case.
            t = t.expand(heatmap.size(0))
        t_embed = self.time_embedding(t)  # [B, embed_dim]

        x = self.conv1(heatmap) + self._channel_bias(self.fc1, t_embed)
        x = self.conv2(x) + self._channel_bias(self.fc2, t_embed)
        x = self.conv3(x) + self._channel_bias(self.fc3, t_embed)
        return x

    @staticmethod
    def _channel_bias(proj: nn.Module, t_embed: torch.Tensor) -> torch.Tensor:
        """Project ``t_embed`` and reshape ``[B, C]`` -> ``[B, C, 1, 1]``.

        The trailing singleton axes let the result broadcast over H and W
        when added to a ``[B, C, H, W]`` feature map.
        """
        return proj(t_embed)[:, :, None, None]
|
|
|