mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 修改 batch_size
This commit is contained in:
parent
27b8c56cd2
commit
ae53194694
@ -138,7 +138,7 @@ class GinkaRNN(nn.Module):
|
||||
"""
|
||||
hidden = self.gru(feat_fusion, hidden)
|
||||
logits = self.fc(hidden)
|
||||
return logits, hidden
|
||||
return F.sigmoid(logits), hidden
|
||||
|
||||
class GinkaRNNModel(nn.Module):
|
||||
def __init__(self, device: torch.device, start_tile=31, width=13, height=13):
|
||||
|
||||
@ -50,7 +50,7 @@ from shared.image import matrix_to_image_cv
|
||||
# 26-28. 三种等级的怪物
|
||||
# 29. 入口,不区分楼梯和箭头
|
||||
|
||||
BATCH_SIZE = 8
|
||||
BATCH_SIZE = 96
|
||||
|
||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
@ -114,8 +114,7 @@ def train():
|
||||
val_cond = batch["val_cond"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
|
||||
with torch.autograd.set_detect_anomaly(True):
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||
|
||||
loss = criterion.rnn_loss(fake_logits, target_map)
|
||||
|
||||
@ -125,7 +124,7 @@ def train():
|
||||
|
||||
iters += 1
|
||||
|
||||
# if iters % 100 == 0:
|
||||
# if iters % 50 == 0:
|
||||
# avg_loss_ginka = loss_total_ginka.item() / iters
|
||||
|
||||
# tqdm.write(
|
||||
@ -160,7 +159,7 @@ def train():
|
||||
|
||||
B, T = val_cond.shape
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||
|
||||
|
||||
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
|
||||
|
||||
fake_map = fake_map.cpu().numpy()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user