From 1ccac9e60dfedbfe268f6e3a503c397c860c2245 Mon Sep 17 00:00:00 2001
From: unanmed <1319491857@qq.com>
Date: Fri, 12 Dec 2025 16:41:27 +0800
Subject: [PATCH] fix: rnn

---
 ginka/generator/loss.py |  1 +
 ginka/generator/rnn.py  | 23 +++++---------
 ginka/train_rnn.py      | 70 ++++++++++++++++++++++-------------
 3 files changed, 46 insertions(+), 48 deletions(-)

diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py
index 6246700..1220cbb 100644
--- a/ginka/generator/loss.py
+++ b/ginka/generator/loss.py
@@ -407,4 +407,5 @@ class RNNGinkaLoss:
         pass
 
     def rnn_loss(self, fake, target):
+        target = F.one_hot(target, num_classes=32).float()
         return F.cross_entropy(fake, target)
diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py
index e40b55d..d24507e 100644
--- a/ginka/generator/rnn.py
+++ b/ginka/generator/rnn.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 class GinkaRNN(nn.Module):
-    def __init__(self, tile_classes=32, cond_dim=256, input_dim=256, hidden_dim=512, num_layers=1):
+    def __init__(self, tile_classes=32, cond_dim=256, input_dim=256, hidden_dim=1024, num_layers=2):
         super().__init__()
 
         # Input section
@@ -31,34 +31,25 @@ class GinkaRNN(nn.Module):
         return logits
 
 def print_memory(tag=""):
-    print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
+    print(f"{tag} | current VRAM: {torch.cuda.memory_allocated('cuda:1') / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated('cuda:1') / 1024**2:.2f} MB")
 
 if __name__ == "__main__":
-    input = torch.rand(1, 32, 32, 32).cuda()
-    tag = torch.rand(1, 64).cuda()
-    val = torch.rand(1, 16).cuda()
+    input = torch.argmax(torch.rand(1, 32, 13 * 13).cuda(1), dim=1)
+    cond = torch.rand(1, 256).cuda(1)
 
     # Initialize the model
-    model = GinkaRNN().cuda()
+    model = GinkaRNN().cuda(1)
     print_memory("after init")
 
     # Forward pass
     start = time.perf_counter()
-    fake0 = model(input, 0, tag, val)
-    fake1 = model(F.softmax(fake0, dim=1), 1, tag, val)
-    fake2 = model(F.softmax(fake1, dim=1), 1, tag, val)
-    fake3 = model(F.softmax(fake2, dim=1), 1, tag, val)
+    fake = model(input, cond)
     end = time.perf_counter()
     print_memory("after forward")
     print(f"inference time: {end - start}")
     print(f"input shape: feat={input.shape}")
-    print(f"output shape: output={fake3.shape}")
-    print(f"Random parameters: {sum(p.numel() for p in model.head.parameters())}")
-    print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
-    print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
-    print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
-    print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
+    print(f"output shape: output={fake.shape}")
     print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py
index 28d2854..933d143 100644
--- a/ginka/train_rnn.py
+++ b/ginka/train_rnn.py
@@ -55,6 +55,7 @@ BATCH_SIZE = 8
 device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
 os.makedirs("result", exist_ok=True)
 os.makedirs("result/rnn", exist_ok=True)
+os.makedirs("result/ginka_rnn_img", exist_ok=True)
 
 disable_tqdm = not sys.stdout.isatty()
 
@@ -75,16 +76,16 @@ def train():
     args = parse_arguments()
 
-    cond_inj = ConditionEncoder()
-    ginka_rnn = GinkaRNN()
+    cond_inj = ConditionEncoder().to(device)
+    ginka_rnn = GinkaRNN().to(device)
 
     dataset = GinkaRNNDataset(args.train, device)
     dataset_val = GinkaRNNDataset(args.validate, device)
     dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
     dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
 
+    optimizer_ginka = optim.Adam(list(ginka_rnn.parameters()) + list(cond_inj.parameters()), lr=1e-3, betas=(0.0, 0.9))
     scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
-    optimizer_ginka = optim.Adam(list(ginka_rnn.parameters()) + list(cond_inj), lr=1e-3, betas=(0.0, 0.9))
 
     criterion = RNNGinkaLoss()
 
@@ -115,7 +116,9 @@ def train():
             val_cond = batch["val_cond"].to(device)
             target_map = batch["target_map"].to(device)
 
-            cond_vec = cond_inj(tag_cond, val_cond, 0)
+            B, D = val_cond.shape
+            stage = torch.Tensor([0]).expand(B, 1).to(device)
+            cond_vec = cond_inj(tag_cond, val_cond, stage)
             fake = ginka_rnn(target_map, cond_vec)
 
             loss = criterion.rnn_loss(fake, target_map)
@@ -126,18 +129,18 @@ def train():
 
             iters += 1
 
-            if iters % 100 == 0:
-                avg_loss_ginka = loss_total_ginka.item() / iters
+            # if iters % 100 == 0:
+            #     avg_loss_ginka = loss_total_ginka.item() / iters
 
-                tqdm.write(
-                    f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
-                    f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " +
-                    f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
-                )
+            #     tqdm.write(
+            #         f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
+            #         f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " +
+            #         f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
+            #     )
 
         avg_loss_ginka = loss_total_ginka.item() / iters
         tqdm.write(
-            f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
+            f"[Epoch {epoch} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
             f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " +
             f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
         )
@@ -150,31 +153,34 @@ def train():
             torch.save({
                 "model_state": ginka_rnn.state_dict(),
                 "optim_state": optimizer_ginka.state_dict(),
-            }, f"result/wgan/ginka-{epoch + 1}.pth")
+            }, f"result/rnn/ginka-{epoch + 1}.pth")
 
-        val_loss_total = torch.Tensor([0]).to(device)
-        with torch.no_grad():
-            idx = 0
-            for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
-                tag_cond = batch["tag_cond"].to(device)
-                val_cond = batch["val_cond"].to(device)
-                target_map = batch["target_map"].to(device)
-
-                cond_vec = cond_inj(tag_cond, val_cond, 0)
-                fake = ginka_rnn(target_map, cond_vec)
-
-                val_loss_total += criterion.rnn_loss(fake, target_map).detach()
-
-                fake_map = torch.argmax(fake, dim=1).cpu().numpy()
-                fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
-                cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img)
-
-                idx += 1
+        val_loss_total = torch.Tensor([0]).to(device)
+        with torch.no_grad():
+            idx = 0
+            for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
+                tag_cond = batch["tag_cond"].to(device)
+                val_cond = batch["val_cond"].to(device)
+                target_map = batch["target_map"].to(device)
+
+                B, T = val_cond.shape
+                stage = torch.Tensor([0]).expand(B, 1).to(device)
+                cond_vec = cond_inj(tag_cond, val_cond, stage)
+                fake = ginka_rnn(target_map, cond_vec)
+
+                val_loss_total += criterion.rnn_loss(fake, target_map).detach()
+
+                B, T, C = fake.shape
+                fake_map = torch.argmax(fake, dim=-1).reshape(B, 13, 13).cpu().numpy()
+                fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
+                cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img)
+
+                idx += 1
 
     print("Train ended.")
     torch.save({
         "model_state": ginka_rnn.state_dict(),
-    }, f"result/ginka.pth")
+    }, f"result/ginka_rnn.pth")
 
 if __name__ == "__main__":
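
On the rnn_loss change: a minimal, self-contained sketch of why the one-hot
conversion type-checks (the batch and shape values here are illustrative
assumptions, not taken from the repo). Since PyTorch 1.10, F.cross_entropy
accepts floating-point class-probability targets shaped like the input, with
the class dimension on dim 1; integer index targets need no one-hot at all,
and both forms give the same value:

    import torch
    import torch.nn.functional as F

    B, T, C = 2, 13 * 13, 32                 # batch, tiles per map, tile classes (assumed)
    logits = torch.randn(B, C, T)            # cross_entropy expects the class dim at dim 1
    target = torch.randint(0, C, (B, T))     # integer tile indices

    # Patched form: one-hot the indices into float probabilities, class dim first.
    target_probs = F.one_hot(target, num_classes=C).float().permute(0, 2, 1)  # (B, C, T)
    loss_probs = F.cross_entropy(logits, target_probs)

    # Index form: no one-hot needed.
    loss_idx = F.cross_entropy(logits, target)
    print(loss_probs.item(), loss_idx.item())  # identical up to float rounding

One caveat worth checking: the validation loop in this patch unpacks
B, T, C = fake.shape, which suggests the generator emits class-last logits;
if so, both fake and the one-hot target would need the class dimension moved
to dim 1 before the cross_entropy call.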
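On the optimizer fix in train_rnn.py: the old code built the scheduler before
optimizer_ginka was assigned (a reference-before-assignment error at runtime)
and passed list(cond_inj) to Adam, which, depending on the module type, either
fails outright or yields child modules rather than parameters. A small sketch
of the corrected pattern, using a stand-in nn.Sequential since
ConditionEncoder's definition is not part of this patch:

    import torch.nn as nn
    import torch.optim as optim

    # Stand-in for ConditionEncoder (not shown in this patch).
    cond_inj = nn.Sequential(nn.Linear(16, 256), nn.ReLU())

    print(list(cond_inj))                 # child modules -- not what Adam wants
    params = list(cond_inj.parameters())  # leaf tensors with requires_grad

    # The optimizer must exist before the scheduler that wraps it.
    optimizer = optim.Adam(params, lr=1e-3, betas=(0.0, 0.9))
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)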
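Finally, the tensor shapes introduced by the stage and validation changes, with
the 13x13 map size taken from the patch and the rest illustrative:

    import torch

    B, T, C = 8, 13 * 13, 32
    stage = torch.Tensor([0]).expand(B, 1)      # (B, 1) zeros, broadcast without copying
    fake = torch.randn(B, T, C)                 # per-tile logits, class dim last

    # argmax over the class dim, then fold the flat tile sequence into a 13x13 map
    fake_map = torch.argmax(fake, dim=-1).reshape(B, 13, 13)
    print(stage.shape, fake_map.shape)          # torch.Size([8, 1]) torch.Size([8, 13, 13])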