mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 验证
This commit is contained in:
parent
19fadf50bd
commit
f35d0b03af
@ -37,7 +37,7 @@ from .maskGIT.mask import MapMask
|
||||
# 6. 道具热力图, 7. 入口热力图, 8. 门热力图
|
||||
|
||||
BATCH_SIZE = 128
|
||||
VAL_BATCH_DIVIDER = 128
|
||||
VAL_BATCH_DIVIDER = 64
|
||||
NUM_CLASSES = 16
|
||||
MASK_TOKEN = 15
|
||||
GENERATE_STEP = 8
|
||||
@ -182,7 +182,7 @@ def train():
|
||||
logits = model(masked_input, cond, heatmap)
|
||||
|
||||
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
||||
loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6)
|
||||
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
|
||||
|
||||
val_loss_total += loss.detach()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user