# Mirror of https://github.com/unanmed/ginka-generator.git
# Synced 2026-05-19 00:01:13 +08:00 · 58 lines · 2.2 KiB · Python
import torch
|
||
import torch.nn as nn
|
||
|
||
class AdaLN(nn.Module):
    """Adaptive LayerNorm: a condition vector predicts the scale and shift.

    The input is normalised by a parameter-free ``nn.LayerNorm`` and then
    modulated as ``(1 + gamma) * norm(x) + beta``, where ``gamma`` and
    ``beta`` are the two halves of one linear projection of the condition.

    Args:
        d_model: Size of the last (feature) dimension of ``x``.
        d_cond: Size of the last dimension of the condition ``c``.
    """

    def __init__(self, d_model: int, d_cond: int):
        super().__init__()
        # The built-in affine parameters are disabled because gamma/beta are
        # produced per sample by ``proj`` in ``forward``.
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        # A single projection yields both gamma and beta (hence d_model * 2).
        self.proj = nn.Linear(d_cond, d_model * 2)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` and modulate it with parameters predicted from ``c``.

        Args:
            x: Features of shape ``[B, S, d_model]``.
            c: Condition of shape ``[B, d_cond]`` (shared by every position),
                or ``[B, S, d_cond]`` for a per-position condition.

        Returns:
            The modulated features, same shape as ``x``.
        """
        gamma, beta = self.proj(c).chunk(2, dim=-1)  # each [..., d_model]
        if c.dim() != x.dim():
            # Global condition: insert the sequence axis so gamma/beta
            # broadcast over all S positions. A condition that already has
            # the rank of ``x`` is applied position by position instead.
            gamma = gamma.unsqueeze(1)
            beta = beta.unsqueeze(1)
        return (1 + gamma) * self.norm(x) + beta
|
||
|
||
class CondTransformerLayer(nn.Module):
    """Transformer encoder layer with AdaLN condition injection (pre-norm).

    Structure: AdaLN -> self-attention -> residual add, followed by
    AdaLN -> feed-forward network -> residual add.

    Args:
        d_model: Feature size of the token sequence.
        nhead: Number of attention heads.
        dim_ff: Hidden size of the feed-forward network.
        d_cond: Size of the condition vector fed to both AdaLN modules.
    """

    def __init__(self, d_model: int, nhead: int, dim_ff: int, d_cond: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.adaln1 = AdaLN(d_model, d_cond)
        self.adaln2 = AdaLN(d_model, d_cond)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Linear(dim_ff, d_model),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Apply the conditioned attention and feed-forward sub-layers.

        Args:
            x: Token features of shape ``[B, S, d_model]``.
            c: Condition vector of shape ``[B, d_cond]``.

        Returns:
            Updated token features of shape ``[B, S, d_model]``.
        """
        # Attention sub-layer: normalise first, then add back the raw input.
        attn_in = self.adaln1(x, c)
        # The attention weights are never used, so skip computing them.
        attn_out, _ = self.self_attn(
            attn_in, attn_in, attn_in, need_weights=False
        )
        x = x + attn_out

        # Feed-forward sub-layer, again with a pre-norm residual connection.
        x = x + self.ffn(self.adaln2(x, c))
        return x
|
||
|
||
class Transformer(nn.Module):
    """Encoder-only Transformer whose layers inject a global condition via AdaLN.

    Args:
        d_model: Feature size of the tokens; also used as the condition size.
        dim_ff: Hidden size of each layer's feed-forward network.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked ``CondTransformerLayer`` modules.
    """

    def __init__(
        self, d_model: int = 256, dim_ff: int = 512,
        nhead: int = 8, num_layers: int = 4
    ):
        super().__init__()
        # The condition vector shares the model width, hence d_cond=d_model.
        self.layers = nn.ModuleList([
            CondTransformerLayer(
                d_model=d_model, nhead=nhead, dim_ff=dim_ff, d_cond=d_model
            )
            for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every layer, passing the same condition to each.

        Args:
            x: Token features of shape ``[B, S, d_model]``.
            c: Global condition vector of shape ``[B, d_model]``.

        Returns:
            The output of the last layer.
        """
        for layer in self.layers:
            x = layer(x, c)
        return x
|