mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 插值图片显示
This commit is contained in:
parent
f22f06ef2d
commit
169a514dd1
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
from datetime import datetime
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -86,7 +87,7 @@ def train():
|
||||
dataset = GinkaRNNDataset(args.train, device)
|
||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4, shuffle=True)
|
||||
|
||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
||||
@ -186,16 +187,19 @@ def train():
|
||||
cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)
|
||||
|
||||
# 插值
|
||||
map1 = torch.LongTensor(dataset_val.data[0]["map"]).to(device).reshape(1, 13, 13)
|
||||
map2 = torch.LongTensor(dataset_val.data[1]["map"]).to(device).reshape(1, 13, 13)
|
||||
val_length = len(dataset_val.data)
|
||||
index1 = random.randint(0, val_length - 1)
|
||||
index2 = random.randint(0, val_length - 1)
|
||||
map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13)
|
||||
map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13)
|
||||
map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device)
|
||||
map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device)
|
||||
mu1, logvar1 = vae.encoder(map1_onehot)
|
||||
mu2, logvar2 = vae.encoder(map2_onehot)
|
||||
z1 = vae.reparameterize(mu1, logvar1)
|
||||
z2 = vae.reparameterize(mu2, logvar2)
|
||||
real_img1 = matrix_to_image_cv(map1[0], tile_dict)
|
||||
real_img2 = matrix_to_image_cv(map2[0], tile_dict)
|
||||
real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict)
|
||||
real_img2 = matrix_to_image_cv(map2[0].cpu().numpy(), tile_dict)
|
||||
i = 0
|
||||
for t in torch.linspace(0, 1, 8):
|
||||
z = z1 * (1 - t / 8) + z2 * t / 8
|
||||
|
||||
Loading…
Reference in New Issue
Block a user