ginka-generator/ginka/maskGIT/maskGIT.py
2026-05-15 18:15:50 +08:00

58 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Optional

import torch
import torch.nn as nn
class AdaLN(nn.Module):
    """Adaptive LayerNorm whose scale and shift are predicted from a condition vector.

    ``x`` is normalized by a LayerNorm with no learnable affine parameters.
    A linear projection of the condition ``c`` then supplies a per-sample
    scale offset (``gamma``) and shift (``beta``)::

        out = (1 + gamma) * LayerNorm(x) + beta

    Because the scale is ``1 + gamma``, a projection that outputs zeros
    leaves the plain LayerNorm result unchanged.

    Args:
        d_model: Feature size of ``x`` (its last dimension).
        d_cond: Feature size of the condition vector ``c``.
    """

    def __init__(self, d_model: int, d_cond: int):
        super().__init__()
        # The affine part is disabled because it is supplied by ``proj`` instead.
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        # A single projection produces both halves: [gamma | beta].
        self.proj = nn.Linear(d_cond, d_model * 2)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and modulate it with parameters predicted from ``c``.

        Args:
            x: Token features, expected shape ``[B, S, d_model]``.
            c: Condition vectors, expected shape ``[B, d_cond]``.

        Returns:
            The modulated tensor; ``[B, S, d_model]`` for the shapes above.
        """
        modulation = self.proj(c)
        scale, shift = modulation.chunk(2, dim=-1)
        # [B, d_model] -> [B, 1, d_model]: one vector per sample, broadcast
        # across the sequence axis.
        scale = scale[:, None]
        shift = shift[:, None]
        normalized = self.norm(x)
        return (1 + scale) * normalized + shift
class CondTransformerLayer(nn.Module):
    """Pre-norm Transformer encoder layer with AdaLN condition injection.

    Each sub-block normalizes its input with an AdaLN driven by ``c`` and
    adds its output back through a residual connection::

        x = x + SelfAttn(AdaLN_1(x, c))
        x = x + FFN(AdaLN_2(x, c))

    No attention mask is passed, so every position may attend to every other.

    Args:
        d_model: Token feature size.
        nhead: Number of attention heads (must divide ``d_model``, as
            required by ``nn.MultiheadAttention``).
        dim_ff: Hidden width of the feed-forward network.
        d_cond: Feature size of the condition vector ``c``.
    """

    def __init__(self, d_model: int, nhead: int, dim_ff: int, d_cond: int):
        super().__init__()
        # Registration order is kept as-is so parameter order and init are unchanged.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.adaln1 = AdaLN(d_model, d_cond)
        self.adaln2 = AdaLN(d_model, d_cond)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Linear(dim_ff, d_model),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Apply conditioned self-attention and feed-forward sub-blocks.

        Args:
            x: Token features, expected shape ``[B, S, d_model]``
                (the attention module is built with ``batch_first=True``).
            c: Condition vectors, expected shape ``[B, d_cond]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        residual = x
        normed = self.adaln1(x, c)
        # The attention weights are discarded, so don't request them.
        # need_weights=False lets PyTorch use its fused
        # scaled_dot_product_attention path instead of materializing and
        # head-averaging the full [B, S, S] weight matrix.
        attn_out, _ = self.self_attn(normed, normed, normed, need_weights=False)
        x = residual + attn_out

        residual = x
        x = residual + self.ffn(self.adaln2(x, c))
        return x
class Transformer(nn.Module):
    """Encoder-only Transformer built from AdaLN-conditioned layers.

    Every layer receives the same global condition vector ``c``.

    Args:
        d_model: Token feature size.
        dim_ff: Hidden width of each layer's feed-forward network.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked layers.
        d_cond: Feature size of ``c``. When ``None`` (the default) it equals
            ``d_model``, which matches the previously hard-coded behavior.
    """

    def __init__(
        self, d_model: int = 256, dim_ff: int = 512,
        nhead: int = 8, num_layers: int = 4,
        d_cond: Optional[int] = None,
    ):
        super().__init__()
        if d_cond is None:
            d_cond = d_model
        self.layers = nn.ModuleList([
            CondTransformerLayer(d_model=d_model, nhead=nhead, dim_ff=dim_ff, d_cond=d_cond)
            for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every layer in order, conditioning each on ``c``.

        Args:
            x: Token features, expected shape ``[B, S, d_model]``.
            c: Global condition vectors, expected shape ``[B, d_cond]``.

        Returns:
            Output of the last layer; returns ``x`` unchanged when
            ``num_layers`` is 0.
        """
        for layer in self.layers:
            x = layer(x, c)
        return x