mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-20 17:22:41 +08:00
chore: 调整网络参数
This commit is contained in:
parent
01c9e1972e
commit
fee2bcf344
@ -17,11 +17,11 @@ class GinkaMaskGIT(nn.Module):
|
|||||||
self.cond_encoder = GinkaMaskGITCond(cond_dim=cond_dim, heatmap_channel=heatmap_channel, output_dim=d_model)
|
self.cond_encoder = GinkaMaskGITCond(cond_dim=cond_dim, heatmap_channel=heatmap_channel, output_dim=d_model)
|
||||||
|
|
||||||
self.encoder = nn.TransformerEncoder(
|
self.encoder = nn.TransformerEncoder(
|
||||||
nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True),
|
nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'),
|
||||||
num_layers=num_layers
|
num_layers=num_layers
|
||||||
)
|
)
|
||||||
self.decoder = nn.TransformerDecoder(
|
self.decoder = nn.TransformerDecoder(
|
||||||
nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True),
|
nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'),
|
||||||
num_layers=num_layers
|
num_layers=num_layers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -48,6 +48,8 @@ BLUR_MIN_SIZE = 3
|
|||||||
BLUR_MAX_SIZE = 9
|
BLUR_MAX_SIZE = 9
|
||||||
RAND_RATIO = 0.15
|
RAND_RATIO = 0.15
|
||||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||||
|
NUM_LAYERS = 4
|
||||||
|
D_MODEL = 128
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
"cuda:1" if torch.cuda.is_available()
|
"cuda:1" if torch.cuda.is_available()
|
||||||
@ -77,7 +79,7 @@ def train():
|
|||||||
|
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=3, d_model=128).to(device)
|
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=NUM_LAYERS, d_model=D_MODEL).to(device)
|
||||||
masker = MapMask([0.5, 0.5])
|
masker = MapMask([0.5, 0.5])
|
||||||
|
|
||||||
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
|
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user