diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index 253e26e..7f86800 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -17,11 +17,11 @@ class GinkaMaskGIT(nn.Module): self.cond_encoder = GinkaMaskGITCond(cond_dim=cond_dim, heatmap_channel=heatmap_channel, output_dim=d_model) self.encoder = nn.TransformerEncoder( - nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True), + nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'), num_layers=num_layers ) self.decoder = nn.TransformerDecoder( - nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True), + nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'), num_layers=num_layers ) diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 024d03d..399f75e 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -48,6 +48,8 @@ BLUR_MIN_SIZE = 3 BLUR_MAX_SIZE = 9 RAND_RATIO = 0.15 MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 +NUM_LAYERS = 4 +D_MODEL = 128 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -77,7 +79,7 @@ def train(): args = parse_arguments() - model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=3, d_model=128).to(device) + model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=NUM_LAYERS, d_model=D_MODEL).to(device) masker = MapMask([0.5, 0.5]) dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)