# ginka-generator/ginka/maskGIT/maskGIT.py
import torch.nn as nn
class Transformer(nn.Module):
    """Encoder-decoder transformer operating on pre-embedded token sequences.

    The forward pass runs the input through a bidirectional encoder, then
    feeds the same input to the decoder as the target sequence with the
    encoder output as memory. No attention masks are applied — presumably
    intentional for a MaskGIT-style bidirectional model (TODO: confirm).
    """

    def __init__(self, d_model=256, dim_ff=512, nhead=8, num_layers=4):
        super().__init__()
        # Build the shared layer templates first; nn.TransformerEncoder /
        # nn.TransformerDecoder clone them num_layers times internally.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            batch_first=True,
            activation='gelu',
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        dec_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            batch_first=True,
            activation='gelu',
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)

    def forward(self, x):
        """Encode x, then decode x against the encoded memory.

        Args:
            x: float tensor of shape [B, L, d_model] (batch_first).

        Returns:
            Tensor of shape [B, L, d_model].
        """
        memory = self.encoder(x)
        return self.decoder(x, memory)