mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 00:01:13 +08:00
feat: rnn 训练
This commit is contained in:
parent
c79662089b
commit
fa8ded2ecd
@ -8,7 +8,7 @@ import cv2
|
|||||||
from torch_geometric.loader import DataLoader
|
from torch_geometric.loader import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .common.cond import ConditionEncoder
|
from .common.cond import ConditionEncoder
|
||||||
from .generator.rnn import GinkaRNN
|
from .generator.rnn import GinkaRNNModel
|
||||||
from .dataset import GinkaRNNDataset
|
from .dataset import GinkaRNNDataset
|
||||||
from .generator.loss import RNNGinkaLoss
|
from .generator.loss import RNNGinkaLoss
|
||||||
from shared.image import matrix_to_image_cv
|
from shared.image import matrix_to_image_cv
|
||||||
@ -76,15 +76,14 @@ def train():
|
|||||||
|
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
cond_inj = ConditionEncoder().to(device)
|
ginka_rnn = GinkaRNNModel(device).to(device)
|
||||||
ginka_rnn = GinkaRNN().to(device)
|
|
||||||
|
|
||||||
dataset = GinkaRNNDataset(args.train, device)
|
dataset = GinkaRNNDataset(args.train, device)
|
||||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
|
||||||
|
|
||||||
optimizer_ginka = optim.Adam(list(ginka_rnn.parameters()) + list(cond_inj.parameters()), lr=1e-3, betas=(0.0, 0.9))
|
optimizer_ginka = optim.Adam(ginka_rnn.parameters(), lr=1e-3)
|
||||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
|
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
|
||||||
|
|
||||||
criterion = RNNGinkaLoss()
|
criterion = RNNGinkaLoss()
|
||||||
@ -112,16 +111,12 @@ def train():
|
|||||||
iters = 0
|
iters = 0
|
||||||
|
|
||||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||||
tag_cond = batch["tag_cond"].to(device)
|
|
||||||
val_cond = batch["val_cond"].to(device)
|
val_cond = batch["val_cond"].to(device)
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
B, D = val_cond.shape
|
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||||
stage = torch.Tensor([0]).expand(B, 1).to(device)
|
|
||||||
cond_vec = cond_inj(tag_cond, val_cond, stage)
|
|
||||||
fake = ginka_rnn(target_map, cond_vec)
|
|
||||||
|
|
||||||
loss = criterion.rnn_loss(fake, target_map)
|
loss = criterion.rnn_loss(fake_logits, target_map)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer_ginka.step()
|
optimizer_ginka.step()
|
||||||
@ -159,23 +154,25 @@ def train():
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
idx = 0
|
idx = 0
|
||||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||||
tag_cond = batch["tag_cond"].to(device)
|
|
||||||
val_cond = batch["val_cond"].to(device)
|
val_cond = batch["val_cond"].to(device)
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
B, T = val_cond.shape
|
B, T = val_cond.shape
|
||||||
stage = torch.Tensor([0]).expand(B, 1).to(device)
|
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||||
cond_vec = cond_inj(tag_cond, val_cond, stage)
|
|
||||||
fake = ginka_rnn(target_map, cond_vec)
|
|
||||||
|
|
||||||
val_loss_total += criterion.rnn_loss(fake, target_map).detach()
|
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
|
||||||
|
|
||||||
B, T, C = fake.shape
|
fake_map = fake_map.cpu().numpy()
|
||||||
fake_map = torch.argmax(fake, dim=-1).reshape(B, 13, 13).cpu().numpy()
|
|
||||||
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
|
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
|
||||||
cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img)
|
cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img)
|
||||||
|
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
|
avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
||||||
|
tqdm.write(
|
||||||
|
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch} | " +
|
||||||
|
f"Loss: {avg_loss_val:.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
print("Train ended.")
|
print("Train ended.")
|
||||||
torch.save({
|
torch.save({
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user