chore: 修改 batch_size

This commit is contained in:
unanmed 2025-12-13 19:14:48 +08:00
parent 27b8c56cd2
commit ae53194694
2 changed files with 5 additions and 6 deletions

View File

@@ -138,7 +138,7 @@ class GinkaRNN(nn.Module):
"""
hidden = self.gru(feat_fusion, hidden)
logits = self.fc(hidden)
return logits, hidden
return F.sigmoid(logits), hidden
class GinkaRNNModel(nn.Module):
def __init__(self, device: torch.device, start_tile=31, width=13, height=13):

View File

@@ -50,7 +50,7 @@ from shared.image import matrix_to_image_cv
# 26-28. 三种等级的怪物
# 29. 入口,不区分楼梯和箭头
BATCH_SIZE = 8
BATCH_SIZE = 96
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
@@ -114,8 +114,7 @@ def train():
val_cond = batch["val_cond"].to(device)
target_map = batch["target_map"].to(device)
with torch.autograd.set_detect_anomaly(True):
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
loss = criterion.rnn_loss(fake_logits, target_map)
@@ -125,7 +124,7 @@ def train():
iters += 1
# if iters % 100 == 0:
# if iters % 50 == 0:
# avg_loss_ginka = loss_total_ginka.item() / iters
# tqdm.write(
@@ -160,7 +159,7 @@ def train():
B, T = val_cond.shape
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
fake_map = fake_map.cpu().numpy()