diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index d3f0cfb..83561f4 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -403,9 +403,13 @@ class WGANGinkaLoss: return sum(losses) class RNNGinkaLoss: - def __init__(self): + def __init__(self, num_classes): + self.num_classes = num_classes pass def rnn_loss(self, fake, target): - target = F.one_hot(target, num_classes=32).float() - return F.cross_entropy(fake, target, label_smoothing=0.05) + weight = torch.ones(self.num_classes, device=fake.device) + weight[0] = 0.3 + weight[1] = 0.5 + target = F.one_hot(target, num_classes=self.num_classes).float() + return F.cross_entropy(fake, target, label_smoothing=0.1, weight=weight) diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index 7499545..50c206b 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -32,10 +32,12 @@ class GinkaMapPatch(nn.Module): self.patch_cnn = nn.Sequential( nn.Conv2d(tile_classes, 256, 3, padding=1), + nn.Dropout(0.2), nn.BatchNorm2d(256), nn.ReLU(), nn.Conv2d(256, 512, 3), + nn.Dropout(0.2), nn.BatchNorm2d(512), nn.ReLU(), @@ -107,7 +109,8 @@ class GinkaInputFusion(nn.Module): self.transformer = nn.TransformerEncoder( nn.TransformerEncoderLayer( - d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True + d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True, + dropout=0.2 ), num_layers=4 ) @@ -128,11 +131,12 @@ class GinkaInputFusion(nn.Module): return feat[:, 0] class GinkaRNN(nn.Module): - def __init__(self, tile_classes=32, input_dim=256, hidden_dim=2048): + def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512): super().__init__() # GRU self.gru = nn.GRUCell(input_dim, hidden_dim) + self.drop = nn.Dropout(0.2) self.fc = nn.Linear(hidden_dim, tile_classes) def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor): @@ -140,7 +144,7 @@ class GinkaRNN(nn.Module): """ feat_fusion: [B, input_dim] hidden: [B, hidden_dim] """ - hidden = self.gru(feat_fusion, 
hidden) + hidden = self.drop(self.gru(feat_fusion, hidden)) logits = self.fc(hidden) return logits, hidden diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index 9bd7a84..895e077 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -84,7 +84,7 @@ def train(): dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4) - scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2) + scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6) - criterion = RNNGinkaLoss() + criterion = RNNGinkaLoss(32)