mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-16 22:41:14 +08:00
44 lines
1.3 KiB
Python
44 lines
1.3 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class ConditionEncoder(nn.Module):
    """Encode a (tag, val) condition pair into a single feature vector.

    Each input is projected to ``hidden_dim`` by its own linear layer, the
    two projections are concatenated along the feature axis, and the result
    is passed through a small MLP (LayerNorm -> ELU -> Linear -> LayerNorm
    -> ELU -> Linear) that produces ``out_dim`` features.
    """

    def __init__(
        self, tag_dim: int, val_dim: int, hidden_dim: int, out_dim: int
    ) -> None:
        """Build the projection layers and the fusion MLP.

        Args:
            tag_dim: Size of the last dimension of the ``tag`` input.
            val_dim: Size of the last dimension of the ``val`` input.
            hidden_dim: Width of each per-input projection.
            out_dim: Size of the last dimension of the encoded output.
        """
        super().__init__()
        self.tag_embed = nn.Linear(tag_dim, hidden_dim)
        self.val_embed = nn.Linear(val_dim, hidden_dim)
        self.fusion = nn.Sequential(
            nn.LayerNorm(hidden_dim * 2),
            nn.ELU(),

            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.LayerNorm(hidden_dim * 4),
            nn.ELU(),

            nn.Linear(hidden_dim * 4, out_dim),
        )

    def forward(self, tag: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
        """Return the fused encoding of ``tag`` and ``val``.

        Args:
            tag: Tensor whose last dimension is ``tag_dim``.
            val: Tensor whose last dimension is ``val_dim``; its leading
                dimensions must match those of ``tag``.

        Returns:
            Tensor with the same leading dimensions as the inputs and a last
            dimension of ``out_dim``.
        """
        tag = self.tag_embed(tag)
        val = self.val_embed(val)
        # Linear and LayerNorm both act on the last axis, so concatenate
        # there as well. For 2-D inputs this is the same as dim=1; it also
        # keeps unbatched or extra-leading-dimension inputs working.
        feat = torch.cat([tag, val], dim=-1)
        feat = self.fusion(feat)
        return feat
|
|
|
|
class ConditionInjector(nn.Module):
    """Add a projected condition vector to a feature map as a per-channel bias.

    The condition is passed through a small MLP (Linear -> LayerNorm -> ELU
    -> Linear) producing ``out_dim`` values per sample, which are then
    broadcast-added to ``x``.
    """

    def __init__(self, cond_dim: int, out_dim: int) -> None:
        """Build the condition projection MLP.

        Args:
            cond_dim: Size of the last dimension of the ``cond`` input.
            out_dim: Number of values produced per sample; for the
                broadcast in ``forward`` to line up this should equal the
                size of dimension 1 of ``x``.
        """
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(cond_dim, cond_dim * 2),
            nn.LayerNorm(cond_dim * 2),
            nn.ELU(),

            nn.Linear(cond_dim * 2, out_dim),
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the projected condition broadcast over trailing axes.

        Args:
            x: Feature tensor; dimension 0 is treated as batch and
                dimension 1 as the axis the condition is aligned with.
            cond: 2-D tensor whose last dimension is ``cond_dim``.

        Returns:
            ``x + fc(cond)`` with the projected condition reshaped to
            ``(batch, out_dim, 1, ..., 1)``.

        Raises:
            ValueError: If the projected condition is not 2-D.
        """
        bias = self.fc(cond)
        batch, channels = bias.shape
        # Append one singleton axis per trailing dimension of x so the bias
        # stays aligned with dimension 1 of x for inputs of rank 5 and above.
        # Never use fewer than two, which keeps the (B, D, 1, 1) shape the
        # module has always produced for inputs of rank 4 or lower.
        n_trailing = max(x.dim() - 2, 2)
        bias = bias.view(batch, channels, *([1] * n_trailing))
        return x + bias
|