fix: 插值图片显示

This commit is contained in:
unanmed 2026-01-20 00:00:24 +08:00
parent f22f06ef2d
commit 169a514dd1

View File

@ -1,6 +1,7 @@
import argparse
import os
import sys
import random
from datetime import datetime
import torch
import torch.nn.functional as F
@ -86,7 +87,7 @@ def train():
dataset = GinkaRNNDataset(args.train, device)
dataset_val = GinkaRNNDataset(args.validate, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4, shuffle=True)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
@ -186,16 +187,19 @@ def train():
cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)
# 插值
map1 = torch.LongTensor(dataset_val.data[0]["map"]).to(device).reshape(1, 13, 13)
map2 = torch.LongTensor(dataset_val.data[1]["map"]).to(device).reshape(1, 13, 13)
val_length = len(dataset_val.data)
index1 = random.randint(0, val_length - 1)
index2 = random.randint(0, val_length - 1)
map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13)
map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13)
map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device)
map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device)
mu1, logvar1 = vae.encoder(map1_onehot)
mu2, logvar2 = vae.encoder(map2_onehot)
z1 = vae.reparameterize(mu1, logvar1)
z2 = vae.reparameterize(mu2, logvar2)
real_img1 = matrix_to_image_cv(map1[0], tile_dict)
real_img2 = matrix_to_image_cv(map2[0], tile_dict)
real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict)
real_img2 = matrix_to_image_cv(map2[0].cpu().numpy(), tile_dict)
i = 0
for t in torch.linspace(0, 1, 8):
z = z1 * (1 - t / 8) + z2 * t / 8