From 169a514dd13cfde9b0dbe38210722b1d1a907d21 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 20 Jan 2026 00:00:24 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=8F=92=E5=80=BC=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index 74844a7..1589a8a 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -1,6 +1,7 @@ import argparse import os import sys +import random from datetime import datetime import torch import torch.nn.functional as F @@ -86,7 +87,7 @@ def train(): dataset = GinkaRNNDataset(args.train, device) dataset_val = GinkaRNNDataset(args.validate, device) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4) + dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4, shuffle=True) optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4) scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6) @@ -186,16 +187,19 @@ def train(): cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img) # 插值 - map1 = torch.LongTensor(dataset_val.data[0]["map"]).to(device).reshape(1, 13, 13) - map2 = torch.LongTensor(dataset_val.data[1]["map"]).to(device).reshape(1, 13, 13) + val_length = len(dataset_val.data) + index1 = random.randint(0, val_length - 1) + index2 = random.randint(0, val_length - 1) + map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13) + map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13) map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device) map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device) mu1, logvar1 = vae.encoder(map1_onehot) mu2, logvar2 = vae.encoder(map2_onehot) z1 = vae.reparameterize(mu1, logvar1) z2 = vae.reparameterize(mu2, logvar2) - real_img1 = matrix_to_image_cv(map1[0], tile_dict) - real_img2 = matrix_to_image_cv(map2[0], tile_dict) + real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict) + real_img2 = matrix_to_image_cv(map2[0].cpu().numpy(), tile_dict) i = 0 for t in torch.linspace(0, 1, 8): z = z1 * (1 - t / 8) + z2 * t / 8