feat: RNN training

This commit is contained in:
unanmed 2025-12-13 13:13:22 +08:00
parent c79662089b
commit fa8ded2ecd

View File

@ -8,7 +8,7 @@ import cv2
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .common.cond import ConditionEncoder
from .generator.rnn import GinkaRNN
from .generator.rnn import GinkaRNNModel
from .dataset import GinkaRNNDataset
from .generator.loss import RNNGinkaLoss
from shared.image import matrix_to_image_cv
@ -76,15 +76,14 @@ def train():
args = parse_arguments()
cond_inj = ConditionEncoder().to(device)
ginka_rnn = GinkaRNN().to(device)
ginka_rnn = GinkaRNNModel(device).to(device)
dataset = GinkaRNNDataset(args.train, device)
dataset_val = GinkaRNNDataset(args.validate, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
optimizer_ginka = optim.Adam(list(ginka_rnn.parameters()) + list(cond_inj.parameters()), lr=1e-3, betas=(0.0, 0.9))
optimizer_ginka = optim.Adam(ginka_rnn.parameters(), lr=1e-3)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
criterion = RNNGinkaLoss()
@ -112,16 +111,12 @@ def train():
iters = 0
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
tag_cond = batch["tag_cond"].to(device)
val_cond = batch["val_cond"].to(device)
target_map = batch["target_map"].to(device)
B, D = val_cond.shape
stage = torch.Tensor([0]).expand(B, 1).to(device)
cond_vec = cond_inj(tag_cond, val_cond, stage)
fake = ginka_rnn(target_map, cond_vec)
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
loss = criterion.rnn_loss(fake, target_map)
loss = criterion.rnn_loss(fake_logits, target_map)
loss.backward()
optimizer_ginka.step()
@ -159,24 +154,26 @@ def train():
with torch.no_grad():
idx = 0
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
tag_cond = batch["tag_cond"].to(device)
val_cond = batch["val_cond"].to(device)
target_map = batch["target_map"].to(device)
B, T = val_cond.shape
stage = torch.Tensor([0]).expand(B, 1).to(device)
cond_vec = cond_inj(tag_cond, val_cond, stage)
fake = ginka_rnn(target_map, cond_vec)
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
val_loss_total += criterion.rnn_loss(fake, target_map).detach()
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
B, T, C = fake.shape
fake_map = torch.argmax(fake, dim=-1).reshape(B, 13, 13).cpu().numpy()
fake_map = fake_map.cpu().numpy()
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img)
idx += 1
avg_loss_val = val_loss_total.item() / len(dataloader_val)
tqdm.write(
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch} | " +
f"Loss: {avg_loss_val:.6f}"
)
print("Train ended.")
torch.save({
"model_state": ginka_rnn.state_dict(),