mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 12:57:15 +08:00
chore: 调整网络参数
This commit is contained in:
parent
01c9e1972e
commit
fee2bcf344
@ -17,11 +17,11 @@ class GinkaMaskGIT(nn.Module):
|
||||
self.cond_encoder = GinkaMaskGITCond(cond_dim=cond_dim, heatmap_channel=heatmap_channel, output_dim=d_model)
|
||||
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True),
|
||||
nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'),
|
||||
num_layers=num_layers
|
||||
)
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True),
|
||||
nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'),
|
||||
num_layers=num_layers
|
||||
)
|
||||
|
||||
|
||||
@ -48,6 +48,8 @@ BLUR_MIN_SIZE = 3
|
||||
BLUR_MAX_SIZE = 9
|
||||
RAND_RATIO = 0.15
|
||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||
NUM_LAYERS = 4
|
||||
D_MODEL = 128
|
||||
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
@ -77,7 +79,7 @@ def train():
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=3, d_model=128).to(device)
|
||||
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=NUM_LAYERS, d_model=D_MODEL).to(device)
|
||||
masker = MapMask([0.5, 0.5])
|
||||
|
||||
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user