mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
23 lines
762 B
Python
23 lines
762 B
Python
import torch.nn as nn
|
|
|
|
class Transformer(nn.Module):
|
|
def __init__(
|
|
self, d_model=256, dim_ff=512, nhead=8, num_layers=4,
|
|
):
|
|
super().__init__()
|
|
self.encoder = nn.TransformerEncoder(
|
|
nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'),
|
|
num_layers=num_layers
|
|
)
|
|
|
|
self.decoder = nn.TransformerDecoder(
|
|
nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'),
|
|
num_layers=num_layers
|
|
)
|
|
|
|
def forward(self, x):
|
|
# x: [B, L, d_model]
|
|
m = self.encoder(x)
|
|
out = self.decoder(x, m)
|
|
return out
|
|
|