mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 20:41:12 +08:00
chore: 修改 batch_size
This commit is contained in:
parent
27b8c56cd2
commit
ae53194694
@ -138,7 +138,7 @@ class GinkaRNN(nn.Module):
|
|||||||
"""
|
"""
|
||||||
hidden = self.gru(feat_fusion, hidden)
|
hidden = self.gru(feat_fusion, hidden)
|
||||||
logits = self.fc(hidden)
|
logits = self.fc(hidden)
|
||||||
return logits, hidden
|
return F.sigmoid(logits), hidden
|
||||||
|
|
||||||
class GinkaRNNModel(nn.Module):
|
class GinkaRNNModel(nn.Module):
|
||||||
def __init__(self, device: torch.device, start_tile=31, width=13, height=13):
|
def __init__(self, device: torch.device, start_tile=31, width=13, height=13):
|
||||||
|
|||||||
@ -50,7 +50,7 @@ from shared.image import matrix_to_image_cv
|
|||||||
# 26-28. 三种等级的怪物
|
# 26-28. 三种等级的怪物
|
||||||
# 29. 入口,不区分楼梯和箭头
|
# 29. 入口,不区分楼梯和箭头
|
||||||
|
|
||||||
BATCH_SIZE = 8
|
BATCH_SIZE = 96
|
||||||
|
|
||||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
@ -114,8 +114,7 @@ def train():
|
|||||||
val_cond = batch["val_cond"].to(device)
|
val_cond = batch["val_cond"].to(device)
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
with torch.autograd.set_detect_anomaly(True):
|
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
|
||||||
|
|
||||||
loss = criterion.rnn_loss(fake_logits, target_map)
|
loss = criterion.rnn_loss(fake_logits, target_map)
|
||||||
|
|
||||||
@ -125,7 +124,7 @@ def train():
|
|||||||
|
|
||||||
iters += 1
|
iters += 1
|
||||||
|
|
||||||
# if iters % 100 == 0:
|
# if iters % 50 == 0:
|
||||||
# avg_loss_ginka = loss_total_ginka.item() / iters
|
# avg_loss_ginka = loss_total_ginka.item() / iters
|
||||||
|
|
||||||
# tqdm.write(
|
# tqdm.write(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user