mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 损失值计算
This commit is contained in:
parent
3946d83d6c
commit
1962c7a712
@ -403,12 +403,12 @@ class WGANGinkaLoss:
|
||||
return sum(losses)
|
||||
|
||||
class RNNGinkaLoss:
|
||||
def __init__(self, num_classes):
|
||||
def __init__(self, num_classes, device):
|
||||
self.num_classes = num_classes
|
||||
weight = torch.ones(self.num_classes)
|
||||
weight[0] = 0.3
|
||||
weight[1] = 0.5
|
||||
self.weight = weight
|
||||
self.weight = weight.to(device)
|
||||
pass
|
||||
|
||||
def rnn_loss(self, fake, target):
|
||||
@ -416,5 +416,5 @@ class RNNGinkaLoss:
|
||||
fake: [B, C, H, W]
|
||||
target: [B, H, W]
|
||||
"""
|
||||
target = F.one_hot(target, num_classes=self.num_classes).float()
|
||||
target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2)
|
||||
return F.cross_entropy(fake, target, label_smoothing=0.1, weight=self.weight)
|
||||
|
||||
@ -86,7 +86,7 @@ def train():
|
||||
optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
||||
|
||||
criterion = RNNGinkaLoss(32)
|
||||
criterion = RNNGinkaLoss(32, device)
|
||||
|
||||
# 用于生成图片
|
||||
tile_dict = dict()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user