mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 添加 dropout,调整调度与损失值计算
This commit is contained in:
parent
df71e9e61c
commit
bd821aa8e4
@ -403,9 +403,13 @@ class WGANGinkaLoss:
|
||||
return sum(losses)
|
||||
|
||||
class RNNGinkaLoss:
    """Weighted, label-smoothed cross-entropy loss for the RNN map generator.

    Tile classes 0 and 1 are down-weighted (presumably the dominant
    background/wall tiles -- TODO confirm against the tile vocabulary) so
    the generator is not rewarded for emitting only the majority classes.
    """

    def __init__(self, num_classes):
        # Width of the one-hot target and length of the per-class weight
        # vector; must match the generator's logit dimension.
        self.num_classes = num_classes

    def rnn_loss(self, fake, target):
        """Compute the generator loss for one prediction step.

        Args:
            fake: logits; assumed shape [B, num_classes] -- TODO confirm.
            target: integer class indices; assumed shape [B] -- TODO confirm.

        Returns:
            Scalar loss tensor.
        """
        # Build the class weights on the same device/dtype as the logits.
        # A bare torch.ones(n) lives on the CPU in float32 and raises a
        # device (or dtype) mismatch as soon as `fake` sits on the GPU
        # or runs under mixed precision.
        weight = torch.ones(self.num_classes, device=fake.device, dtype=fake.dtype)
        weight[0] = 0.3
        weight[1] = 0.5
        # Soft (probability) targets: one-hot, then cross_entropy applies
        # label smoothing of 0.1 on top.
        target = F.one_hot(target, num_classes=self.num_classes).float()
        return F.cross_entropy(fake, target, label_smoothing=0.1, weight=weight)
|
||||
|
||||
@ -32,10 +32,12 @@ class GinkaMapPatch(nn.Module):
|
||||
|
||||
self.patch_cnn = nn.Sequential(
|
||||
nn.Conv2d(tile_classes, 256, 3, padding=1),
|
||||
nn.Dropout(0.2),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
|
||||
nn.Conv2d(256, 512, 3),
|
||||
nn.Dropout(0.2),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
|
||||
@ -107,7 +109,8 @@ class GinkaInputFusion(nn.Module):
|
||||
|
||||
self.transformer = nn.TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(
|
||||
d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True
|
||||
d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True,
|
||||
dropout=0.2
|
||||
),
|
||||
num_layers=4
|
||||
)
|
||||
@ -128,11 +131,12 @@ class GinkaInputFusion(nn.Module):
|
||||
return feat[:, 0]
|
||||
|
||||
class GinkaRNN(nn.Module):
    """One-step GRU decoder for map generation.

    Each call advances the recurrent state by a single step and emits
    per-tile-class logits; the caller owns the sequence loop.
    """

    def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512):
        super().__init__()

        # GRU
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.drop = nn.Dropout(0.2)
        self.fc = nn.Linear(hidden_dim, tile_classes)

    def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor):
        """
        feat_fusion: [B, input_dim]
        hidden: [B, hidden_dim]
        """
        # NOTE(review): dropout is applied to the state that is both fed
        # to the classifier and carried to the next step; it is the
        # identity in eval mode, so inference is deterministic.
        next_hidden = self.drop(self.gru(feat_fusion, hidden))
        return self.fc(next_hidden), next_hidden
|
||||
|
||||
|
||||
@ -84,7 +84,7 @@ def train():
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
|
||||
|
||||
optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
||||
|
||||
criterion = RNNGinkaLoss()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user