mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: rnn 训练
This commit is contained in:
parent
c79662089b
commit
fa8ded2ecd
@ -8,7 +8,7 @@ import cv2
|
||||
from torch_geometric.loader import DataLoader
|
||||
from tqdm import tqdm
|
||||
from .common.cond import ConditionEncoder
|
||||
from .generator.rnn import GinkaRNN
|
||||
from .generator.rnn import GinkaRNNModel
|
||||
from .dataset import GinkaRNNDataset
|
||||
from .generator.loss import RNNGinkaLoss
|
||||
from shared.image import matrix_to_image_cv
|
||||
@ -76,15 +76,14 @@ def train():
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
cond_inj = ConditionEncoder().to(device)
|
||||
ginka_rnn = GinkaRNN().to(device)
|
||||
ginka_rnn = GinkaRNNModel(device).to(device)
|
||||
|
||||
dataset = GinkaRNNDataset(args.train, device)
|
||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
|
||||
|
||||
optimizer_ginka = optim.Adam(list(ginka_rnn.parameters()) + list(cond_inj.parameters()), lr=1e-3, betas=(0.0, 0.9))
|
||||
optimizer_ginka = optim.Adam(ginka_rnn.parameters(), lr=1e-3)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
|
||||
|
||||
criterion = RNNGinkaLoss()
|
||||
@ -112,16 +111,12 @@ def train():
|
||||
iters = 0
|
||||
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
tag_cond = batch["tag_cond"].to(device)
|
||||
val_cond = batch["val_cond"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
|
||||
B, D = val_cond.shape
|
||||
stage = torch.Tensor([0]).expand(B, 1).to(device)
|
||||
cond_vec = cond_inj(tag_cond, val_cond, stage)
|
||||
fake = ginka_rnn(target_map, cond_vec)
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||
|
||||
loss = criterion.rnn_loss(fake, target_map)
|
||||
loss = criterion.rnn_loss(fake_logits, target_map)
|
||||
|
||||
loss.backward()
|
||||
optimizer_ginka.step()
|
||||
@ -159,24 +154,26 @@ def train():
|
||||
with torch.no_grad():
|
||||
idx = 0
|
||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||
tag_cond = batch["tag_cond"].to(device)
|
||||
val_cond = batch["val_cond"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
|
||||
B, T = val_cond.shape
|
||||
stage = torch.Tensor([0]).expand(B, 1).to(device)
|
||||
cond_vec = cond_inj(tag_cond, val_cond, stage)
|
||||
fake = ginka_rnn(target_map, cond_vec)
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||
|
||||
val_loss_total += criterion.rnn_loss(fake, target_map).detach()
|
||||
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
|
||||
|
||||
B, T, C = fake.shape
|
||||
fake_map = torch.argmax(fake, dim=-1).reshape(B, 13, 13).cpu().numpy()
|
||||
fake_map = fake_map.cpu().numpy()
|
||||
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
|
||||
cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img)
|
||||
|
||||
idx += 1
|
||||
|
||||
avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
||||
tqdm.write(
|
||||
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch} | " +
|
||||
f"Loss: {avg_loss_val:.6f}"
|
||||
)
|
||||
|
||||
print("Train ended.")
|
||||
torch.save({
|
||||
"model_state": ginka_rnn.state_dict(),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user