ginka-generator/ginka/maskGIT/maskGIT.py
unanmed 068940cae0 refactor: 采用 VQ + MaskGIT 方案
Co-authored-by: Copilot <copilot@github.com>
2026-04-26 23:45:56 +08:00

29 lines
1.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch.nn as nn
class Transformer(nn.Module):
    """Encoder-decoder transformer over a batch-first token sequence.

    The module owns an ``nn.TransformerEncoder`` and an
    ``nn.TransformerDecoder`` built with the same width, head count,
    feed-forward size and depth.  Both stacks use ``batch_first=True`` and
    GELU activations.

    Args:
        d_model: Feature width of every token embedding.
        dim_ff: Hidden width of the feed-forward sub-layer in each layer.
        nhead: Number of attention heads per layer.
        num_layers: Number of layers in the encoder and in the decoder.
    """

    def __init__(
        self,
        d_model: int = 256,
        dim_ff: int = 512,
        nhead: int = 8,
        num_layers: int = 4,
    ) -> None:
        super().__init__()
        # Encoder is registered before the decoder so that parameter
        # initialisation order and state_dict key order stay stable.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_ff,
                batch_first=True,
                activation='gelu',
            ),
            num_layers=num_layers,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_ff,
                batch_first=True,
                activation='gelu',
            ),
            num_layers=num_layers,
        )

    def forward(self, x, memory=None):
        """Run the encoder, then decode with or without external memory.

        Args:
            x: Map token sequence, ``[B, S, d_model]``.
            memory: Optional projection of ``z``, ``[B, L, d_model]``, used
                as the cross-attention source.  When ``None`` the module
                falls back to its original auto-encoding behaviour, which
                keeps older call sites working.

        Returns:
            The decoder output tensor.
        """
        enc_out = self.encoder(x)
        if memory is None:
            # Backward-compatible path: the raw input is the decoder target
            # and the encoder output is what it cross-attends to.
            return self.decoder(x, enc_out)
        # Conditioned path: the encoder output acts as the query and the
        # projected z supplies the keys/values.
        return self.decoder(enc_out, memory)