chore: add dropout, adjust LR scheduling and loss computation

unanmed 2025-12-15 12:18:34 +08:00
parent df71e9e61c
commit bd821aa8e4
3 changed files with 15 additions and 7 deletions


@@ -403,9 +403,13 @@ class WGANGinkaLoss:
         return sum(losses)
 
 class RNNGinkaLoss:
-    def __init__(self):
+    def __init__(self, num_classes):
+        self.num_classes = num_classes
         pass
 
     def rnn_loss(self, fake, target):
-        target = F.one_hot(target, num_classes=32).float()
-        return F.cross_entropy(fake, target, label_smoothing=0.05)
+        weight = torch.ones(self.num_classes)
+        weight[0] = 0.3
+        weight[1] = 0.5
+        target = F.one_hot(target, num_classes=self.num_classes).float()
+        return F.cross_entropy(fake, target, label_smoothing=0.1, weight=weight)
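For context, a minimal runnable sketch of the reworked loss; the batch size and the reading of classes 0 and 1 as over-represented tile types are assumptions, not something the diff states:

    import torch
    import torch.nn.functional as F

    num_classes = 32  # matches tile_classes elsewhere in this commit

    # Down-weight classes 0 and 1, presumably the most frequent tiles.
    weight = torch.ones(num_classes)
    weight[0] = 0.3
    weight[1] = 0.5

    logits = torch.randn(8, num_classes)          # fake: [B, C]
    target = torch.randint(0, num_classes, (8,))  # tile indices: [B]

    # F.cross_entropy also accepts class indices directly, so the one_hot
    # conversion is optional. Note the weight tensor is built on the CPU
    # each call; move it to logits.device when training on GPU.
    onehot = F.one_hot(target, num_classes=num_classes).float()
    loss = F.cross_entropy(logits, onehot, label_smoothing=0.1,
                           weight=weight.to(logits.device))

Raising label_smoothing from 0.05 to 0.1 softens the one-hot targets further, while the class weights keep the dominant tiles from dominating the loss.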


@@ -32,10 +32,12 @@ class GinkaMapPatch(nn.Module):
         self.patch_cnn = nn.Sequential(
             nn.Conv2d(tile_classes, 256, 3, padding=1),
+            nn.Dropout(0.2),
             nn.BatchNorm2d(256),
             nn.ReLU(),
             nn.Conv2d(256, 512, 3),
+            nn.Dropout(0.2),
             nn.BatchNorm2d(512),
             nn.ReLU(),
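A self-contained sketch of the patched block; the input channel count and spatial size are assumptions. One design note: on 4-D feature maps nn.Dropout zeroes individual activations, whereas nn.Dropout2d zeroes whole channels, the more common choice for convolutional features:

    import torch
    import torch.nn as nn

    patch_cnn = nn.Sequential(
        nn.Conv2d(32, 256, 3, padding=1),  # 32 assumes tile_classes=32 as in GinkaRNN
        nn.Dropout(0.2),
        nn.BatchNorm2d(256),
        nn.ReLU(),
        nn.Conv2d(256, 512, 3),
        nn.Dropout(0.2),
        nn.BatchNorm2d(512),
        nn.ReLU(),
    )
    x = torch.randn(1, 32, 13, 13)  # [B, C, H, W]; the 13x13 map size is hypothetical
    print(patch_cnn(x).shape)       # [1, 512, 11, 11]: the second conv has no padding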
@@ -107,7 +109,8 @@ class GinkaInputFusion(nn.Module):
         self.transformer = nn.TransformerEncoder(
             nn.TransformerEncoderLayer(
-                d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True
+                d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True,
+                dropout=0.2
             ),
             num_layers=4
         )
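The dropout=0.2 here feeds the encoder layer's internal dropout, applied to the attention weights, the feed-forward activations, and both residual branches; the layer's default is 0.1. A runnable sketch, with d_model assumed to be 256 since the visible diff does not pin it down:

    import torch
    import torch.nn as nn

    d_model = 256  # assumed
    transformer = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(
            d_model=d_model, nhead=2, dim_feedforward=d_model * 2,
            batch_first=True,
            dropout=0.2,  # raised from the layer default of 0.1
        ),
        num_layers=4,
    )
    x = torch.randn(2, 10, d_model)  # [B, seq, d_model] since batch_first=True
    print(transformer(x).shape)      # torch.Size([2, 10, 256])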
@@ -128,11 +131,12 @@ class GinkaInputFusion(nn.Module):
         return feat[:, 0]
 
 class GinkaRNN(nn.Module):
-    def __init__(self, tile_classes=32, input_dim=256, hidden_dim=2048):
+    def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512):
         super().__init__()
         # GRU
         self.gru = nn.GRUCell(input_dim, hidden_dim)
+        self.drop = nn.Dropout(0.2)
         self.fc = nn.Linear(hidden_dim, tile_classes)
 
     def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor):
@@ -140,7 +144,7 @@ class GinkaRNN(nn.Module):
         feat_fusion: [B, input_dim]
         hidden: [B, hidden_dim]
         """
-        hidden = self.gru(feat_fusion, hidden)
+        hidden = self.drop(self.gru(feat_fusion, hidden))
         logits = self.fc(hidden)
         return logits, hidden
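Two changes land here: hidden_dim drops from 2048 to 512, and a Dropout(0.2) now wraps the GRUCell output. Because forward() returns the hidden state after dropout, the dropped state is also what recirculates into the next step; a sketch of that recurrence under the commit's dimensions:

    import torch
    import torch.nn as nn

    tile_classes, input_dim, hidden_dim = 32, 256, 512
    gru = nn.GRUCell(input_dim, hidden_dim)
    drop = nn.Dropout(0.2)
    fc = nn.Linear(hidden_dim, tile_classes)

    feat_fusion = torch.randn(4, input_dim)  # [B, input_dim]
    hidden = torch.zeros(4, hidden_dim)      # [B, hidden_dim]
    for _ in range(5):  # unroll a few generation steps
        # The *dropped* hidden state feeds back into the next step.
        hidden = drop(gru(feat_fusion, hidden))
        logits = fc(hidden)                  # [B, tile_classes]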


@@ -84,7 +84,7 @@ def train():
     dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
     optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
-    scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
+    scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
     criterion = RNNGinkaLoss()
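The warm-restart schedule reset the learning rate after 10 epochs, then after a further 20, 40, and so on (T_mult=2); the replacement is a single cosine decay from 1e-4 down to eta_min over 800 steps. Note also that RNNGinkaLoss() is still constructed here without the new num_classes argument, which will raise a TypeError unless it is updated elsewhere. A minimal comparison sketch with a stand-in parameter list:

    import torch
    from torch import optim

    params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for ginka_rnn.parameters()
    optimizer_ginka = optim.AdamW(params, lr=1e-4, weight_decay=1e-4)

    # One long cosine decay: lr falls from 1e-4 to 1e-6 over 800 scheduler
    # steps, with no restarts.
    scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(
        optimizer_ginka, T_max=800, eta_min=1e-6)

    for epoch in range(3):
        optimizer_ginka.step()
        scheduler_ginka.step()
        print(epoch, scheduler_ginka.get_last_lr())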