mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 03:51:11 +08:00
fix: 交叉熵损失计算
This commit is contained in:
parent
102e19cefb
commit
3946d83d6c
@ -405,6 +405,10 @@ class WGANGinkaLoss:
|
|||||||
class RNNGinkaLoss:
|
class RNNGinkaLoss:
|
||||||
def __init__(self, num_classes):
|
def __init__(self, num_classes):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
|
weight = torch.ones(self.num_classes)
|
||||||
|
weight[0] = 0.3
|
||||||
|
weight[1] = 0.5
|
||||||
|
self.weight = weight
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def rnn_loss(self, fake, target):
|
def rnn_loss(self, fake, target):
|
||||||
@ -412,8 +416,5 @@ class RNNGinkaLoss:
|
|||||||
fake: [B, C, H, W]
|
fake: [B, C, H, W]
|
||||||
target: [B, H, W]
|
target: [B, H, W]
|
||||||
"""
|
"""
|
||||||
weight = torch.ones(self.num_classes)
|
|
||||||
weight[0] = 0.3
|
|
||||||
weight[1] = 0.5
|
|
||||||
target = F.one_hot(target, num_classes=self.num_classes).float()
|
target = F.one_hot(target, num_classes=self.num_classes).float()
|
||||||
return F.cross_entropy(fake, target, label_smoothing=0.1, weight=weight)
|
return F.cross_entropy(fake, target, label_smoothing=0.1, weight=self.weight)
|
||||||
|
|||||||
@ -205,7 +205,7 @@ class GinkaRNNModel(nn.Module):
|
|||||||
map[:, y, x] = tile_id[:]
|
map[:, y, x] = tile_id[:]
|
||||||
now_tile = tile_id if use_self else target_map[:, y, x].detach()
|
now_tile = tile_id if use_self else target_map[:, y, x].detach()
|
||||||
|
|
||||||
return output_logits, map
|
return output_logits.permute(0, 3, 1, 2), map
|
||||||
|
|
||||||
def print_memory(device, tag=""):
|
def print_memory(device, tag=""):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|||||||
@ -86,7 +86,7 @@ def train():
|
|||||||
optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
|
optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
||||||
|
|
||||||
criterion = RNNGinkaLoss()
|
criterion = RNNGinkaLoss(32)
|
||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
tile_dict = dict()
|
tile_dict = dict()
|
||||||
@ -158,7 +158,6 @@ def train():
|
|||||||
val_cond = batch["val_cond"].to(device)
|
val_cond = batch["val_cond"].to(device)
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
B, T = val_cond.shape
|
|
||||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||||
|
|
||||||
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
|
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user