fix: 损失值计算

This commit is contained in:
unanmed 2025-12-15 12:41:39 +08:00
parent 3946d83d6c
commit 1962c7a712
2 changed files with 4 additions and 4 deletions

View File

@ -403,12 +403,12 @@ class WGANGinkaLoss:
return sum(losses)
class RNNGinkaLoss:
def __init__(self, num_classes):
def __init__(self, num_classes, device):
self.num_classes = num_classes
weight = torch.ones(self.num_classes)
weight[0] = 0.3
weight[1] = 0.5
self.weight = weight
self.weight = weight.to(device)
pass
def rnn_loss(self, fake, target):
@ -416,5 +416,5 @@ class RNNGinkaLoss:
fake: [B, C, H, W]
target: [B, H, W]
"""
target = F.one_hot(target, num_classes=self.num_classes).float()
target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2)
return F.cross_entropy(fake, target, label_smoothing=0.1, weight=self.weight)

View File

@ -86,7 +86,7 @@ def train():
optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
criterion = RNNGinkaLoss(32)
criterion = RNNGinkaLoss(32, device)
# 用于生成图片
tile_dict = dict()