mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 03:51:11 +08:00
fix: 插值图片显示
This commit is contained in:
parent
f22f06ef2d
commit
169a514dd1
@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import random
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -86,7 +87,7 @@ def train():
|
|||||||
dataset = GinkaRNNDataset(args.train, device)
|
dataset = GinkaRNNDataset(args.train, device)
|
||||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4, shuffle=True)
|
||||||
|
|
||||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
|
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
||||||
@ -186,16 +187,19 @@ def train():
|
|||||||
cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)
|
cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)
|
||||||
|
|
||||||
# 插值
|
# 插值
|
||||||
map1 = torch.LongTensor(dataset_val.data[0]["map"]).to(device).reshape(1, 13, 13)
|
val_length = len(dataset_val.data)
|
||||||
map2 = torch.LongTensor(dataset_val.data[1]["map"]).to(device).reshape(1, 13, 13)
|
index1 = random.randint(0, val_length - 1)
|
||||||
|
index2 = random.randint(0, val_length - 1)
|
||||||
|
map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13)
|
||||||
|
map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13)
|
||||||
map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device)
|
map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device)
|
||||||
map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device)
|
map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device)
|
||||||
mu1, logvar1 = vae.encoder(map1_onehot)
|
mu1, logvar1 = vae.encoder(map1_onehot)
|
||||||
mu2, logvar2 = vae.encoder(map2_onehot)
|
mu2, logvar2 = vae.encoder(map2_onehot)
|
||||||
z1 = vae.reparameterize(mu1, logvar1)
|
z1 = vae.reparameterize(mu1, logvar1)
|
||||||
z2 = vae.reparameterize(mu2, logvar2)
|
z2 = vae.reparameterize(mu2, logvar2)
|
||||||
real_img1 = matrix_to_image_cv(map1[0], tile_dict)
|
real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict)
|
||||||
real_img2 = matrix_to_image_cv(map2[0], tile_dict)
|
real_img2 = matrix_to_image_cv(map2[0].cpu().numpy(), tile_dict)
|
||||||
i = 0
|
i = 0
|
||||||
for t in torch.linspace(0, 1, 8):
|
for t in torch.linspace(0, 1, 8):
|
||||||
z = z1 * (1 - t / 8) + z2 * t / 8
|
z = z1 * (1 - t / 8) + z2 * t / 8
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user