fix: 损失值计算

This commit is contained in:
unanmed 2025-12-15 12:41:39 +08:00
parent 3946d83d6c
commit 1962c7a712
2 changed files with 4 additions and 4 deletions

View File

@ -403,12 +403,12 @@ class WGANGinkaLoss:
return sum(losses) return sum(losses)
class RNNGinkaLoss: class RNNGinkaLoss:
def __init__(self, num_classes): def __init__(self, num_classes, device):
self.num_classes = num_classes self.num_classes = num_classes
weight = torch.ones(self.num_classes) weight = torch.ones(self.num_classes)
weight[0] = 0.3 weight[0] = 0.3
weight[1] = 0.5 weight[1] = 0.5
self.weight = weight self.weight = weight.to(device)
pass pass
def rnn_loss(self, fake, target): def rnn_loss(self, fake, target):
@ -416,5 +416,5 @@ class RNNGinkaLoss:
fake: [B, C, H, W] fake: [B, C, H, W]
target: [B, H, W] target: [B, H, W]
""" """
target = F.one_hot(target, num_classes=self.num_classes).float() target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2)
return F.cross_entropy(fake, target, label_smoothing=0.1, weight=self.weight) return F.cross_entropy(fake, target, label_smoothing=0.1, weight=self.weight)

View File

@ -86,7 +86,7 @@ def train():
optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4) optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6) scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
criterion = RNNGinkaLoss(32) criterion = RNNGinkaLoss(32, device)
# 用于生成图片 # 用于生成图片
tile_dict = dict() tile_dict = dict()