feat: rnn 训练

This commit is contained in:
unanmed 2025-12-13 13:13:22 +08:00
parent c79662089b
commit fa8ded2ecd

View File

@ -8,7 +8,7 @@ import cv2
from torch_geometric.loader import DataLoader from torch_geometric.loader import DataLoader
from tqdm import tqdm from tqdm import tqdm
from .common.cond import ConditionEncoder from .common.cond import ConditionEncoder
from .generator.rnn import GinkaRNN from .generator.rnn import GinkaRNNModel
from .dataset import GinkaRNNDataset from .dataset import GinkaRNNDataset
from .generator.loss import RNNGinkaLoss from .generator.loss import RNNGinkaLoss
from shared.image import matrix_to_image_cv from shared.image import matrix_to_image_cv
@ -76,15 +76,14 @@ def train():
args = parse_arguments() args = parse_arguments()
cond_inj = ConditionEncoder().to(device) ginka_rnn = GinkaRNNModel(device).to(device)
ginka_rnn = GinkaRNN().to(device)
dataset = GinkaRNNDataset(args.train, device) dataset = GinkaRNNDataset(args.train, device)
dataset_val = GinkaRNNDataset(args.validate, device) dataset_val = GinkaRNNDataset(args.validate, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
optimizer_ginka = optim.Adam(list(ginka_rnn.parameters()) + list(cond_inj.parameters()), lr=1e-3, betas=(0.0, 0.9)) optimizer_ginka = optim.Adam(ginka_rnn.parameters(), lr=1e-3)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2) scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
criterion = RNNGinkaLoss() criterion = RNNGinkaLoss()
@ -112,16 +111,12 @@ def train():
iters = 0 iters = 0
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
tag_cond = batch["tag_cond"].to(device)
val_cond = batch["val_cond"].to(device) val_cond = batch["val_cond"].to(device)
target_map = batch["target_map"].to(device) target_map = batch["target_map"].to(device)
B, D = val_cond.shape fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
stage = torch.Tensor([0]).expand(B, 1).to(device)
cond_vec = cond_inj(tag_cond, val_cond, stage)
fake = ginka_rnn(target_map, cond_vec)
loss = criterion.rnn_loss(fake, target_map) loss = criterion.rnn_loss(fake_logits, target_map)
loss.backward() loss.backward()
optimizer_ginka.step() optimizer_ginka.step()
@ -159,23 +154,25 @@ def train():
with torch.no_grad(): with torch.no_grad():
idx = 0 idx = 0
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
tag_cond = batch["tag_cond"].to(device)
val_cond = batch["val_cond"].to(device) val_cond = batch["val_cond"].to(device)
target_map = batch["target_map"].to(device) target_map = batch["target_map"].to(device)
B, T = val_cond.shape B, T = val_cond.shape
stage = torch.Tensor([0]).expand(B, 1).to(device) fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
cond_vec = cond_inj(tag_cond, val_cond, stage)
fake = ginka_rnn(target_map, cond_vec)
val_loss_total += criterion.rnn_loss(fake, target_map).detach() val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
B, T, C = fake.shape fake_map = fake_map.cpu().numpy()
fake_map = torch.argmax(fake, dim=-1).reshape(B, 13, 13).cpu().numpy()
fake_img = matrix_to_image_cv(fake_map[0], tile_dict) fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img) cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img)
idx += 1 idx += 1
avg_loss_val = val_loss_total.item() / len(dataloader_val)
tqdm.write(
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch} | " +
f"Loss: {avg_loss_val:.6f}"
)
print("Train ended.") print("Train ended.")
torch.save({ torch.save({