chore: 调整网络参数

This commit is contained in:
unanmed 2026-03-30 14:40:58 +08:00
parent 01c9e1972e
commit fee2bcf344
2 changed files with 5 additions and 3 deletions

View File

@ -17,11 +17,11 @@ class GinkaMaskGIT(nn.Module):
self.cond_encoder = GinkaMaskGITCond(cond_dim=cond_dim, heatmap_channel=heatmap_channel, output_dim=d_model)
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True),
nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'),
num_layers=num_layers
)
self.decoder = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True),
nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'),
num_layers=num_layers
)

View File

@ -48,6 +48,8 @@ BLUR_MIN_SIZE = 3
BLUR_MAX_SIZE = 9
RAND_RATIO = 0.15
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
NUM_LAYERS = 4
D_MODEL = 128
device = torch.device(
"cuda:1" if torch.cuda.is_available()
@ -77,7 +79,7 @@ def train():
args = parse_arguments()
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=3, d_model=128).to(device)
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=NUM_LAYERS, d_model=D_MODEL).to(device)
masker = MapMask([0.5, 0.5])
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)