ginka-generator/ginka/train_wgan.py
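
"""WGAN-style adversarial training for the Ginka map generator.

GinkaModel is trained as the generator against the MinamoScoreModule critic.
Training proceeds through staged curriculum learning: early stages reconstruct
real maps from masked inputs with a growing mask ratio, later stages chain the
three generator stages end to end, and the final stage additionally trains the
RandomInputHead input module.

The relative imports assume the file is executed as a module, e.g.
`python -m ginka.train_wgan --train ginka-dataset.json --epochs 100`
(example invocation; adjust flags and paths as needed).
"""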

import argparse
import os
import sys
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
import cv2
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import GinkaModel
from .dataset import GinkaWGANDataset
from .model.loss import WGANGinkaLoss
from .model.input import RandomInputHead
from minamo.model.model import MinamoScoreModule
from shared.image import matrix_to_image_cv

BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

os.makedirs("result", exist_ok=True)
os.makedirs("result/wgan", exist_ok=True)
os.makedirs("result/ginka_img", exist_ok=True)  # validation images are written here at checkpoints
disable_tqdm = not sys.stdout.isatty()


def parse_arguments():
    parser = argparse.ArgumentParser(description="Ginka WGAN training")
    parser.add_argument("--resume", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
    parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--curr_epoch", type=int, default=20)  # minimum number of epochs for curriculum learning
    args = parser.parse_args()
    return args


def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    fake1, _ = gen(masked1, 1)
    fake2, _ = gen(masked2, 2)
    fake3, _ = gen(masked3, 3)
    if detach:
        return fake1.detach(), fake2.detach(), fake3.detach()
    else:
        return fake1, fake2, fake3
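

# gen_total chains the three generator stages: each stage's softmaxed output is
# fed into the next stage, optionally detached so every stage is optimized against
# a fixed input; the `random` flag is forwarded to the first stage (presumably
# selecting the generator's random-input path).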
def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    if progress_detach:
        fake1, x_in = gen(input.detach(), 1, random)
        fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2)
        fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3)
    else:
        fake1, x_in = gen(input, 1, random)
        fake2, _ = gen(F.softmax(fake1, dim=1), 2)
        fake3, _ = gen(F.softmax(fake2, dim=1), 3)
    if result_detach:
        return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
    else:
        return fake1, fake2, fake3, x_in


def train():
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
    args = parse_arguments()
    c_steps = 5
    g_steps = 1
    # training stage
    train_stage = 1
    last_stage = False
    mask_ratio = 0.2  # masked-region ratio; grows step by step, training enters stage 2 once it reaches 0.9
    stage_epoch = 0  # epochs spent in the current stage, used to control the training schedule
    ginka = GinkaModel().to(device)
    ginka_head = RandomInputHead().to(device)
    minamo = MinamoScoreModule().to(device)
    dataset = GinkaWGANDataset(args.train, device)
    dataset_val = GinkaWGANDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
    # Adam with betas=(0.0, 0.9), the usual WGAN-GP setting
    optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
    optimizer_head = optim.Adam(ginka_head.parameters(), lr=1e-4, betas=(0.0, 0.9))
    optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
    # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
    # scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
    criterion = WGANGinkaLoss()
    # tiles used to render generated maps as images
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
    if args.resume:
        data_ginka = torch.load(args.state_ginka, map_location=device)
        data_minamo = torch.load(args.state_minamo, map_location=device)
        ginka.load_state_dict(data_ginka["model_state"], strict=False)
        minamo.load_state_dict(data_minamo["model_state"], strict=False)
        if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None:
            c_steps = data_ginka["c_steps"]
            g_steps = data_ginka["g_steps"]
        if data_ginka.get("mask_ratio") is not None:
            mask_ratio = data_ginka["mask_ratio"]
        if data_ginka.get("stage_epoch") is not None:
            stage_epoch = data_ginka["stage_epoch"]
        if data_ginka.get("stage") is not None:
            train_stage = data_ginka["stage"]
        if data_ginka.get("last_stage") is not None:
            last_stage = data_ginka["last_stage"]
        if args.load_optim:
            if data_ginka.get("optim_state") is not None:
                optimizer_ginka.load_state_dict(data_ginka["optim_state"])
            if data_minamo.get("optim_state") is not None:
                optimizer_minamo.load_state_dict(data_minamo["optim_state"])
        dataset.train_stage = train_stage
        dataset.mask_ratio1 = mask_ratio
        dataset.mask_ratio2 = mask_ratio
        dataset.mask_ratio3 = mask_ratio
        dataset_val.train_stage = train_stage
        dataset_val.mask_ratio1 = mask_ratio
        dataset_val.mask_ratio2 = mask_ratio
        dataset_val.mask_ratio3 = mask_ratio
        print("Train from loaded state.")
    low_loss_epochs = 0
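
    # Stage schedule (as implemented by the control logic at the end of each epoch):
    #   stage 1: curriculum training from masked real maps, mask ratio growing towards 0.9
    #   stage 2: curriculum training at the maximum mask ratio
    #   stage 3: the three generator stages are chained end to end
    #   stage 4: chained generation driven by the random input head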
    for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
        loss_total_minamo = torch.Tensor([0]).to(device)
        loss_total_ginka = torch.Tensor([0]).to(device)
        dis_total = torch.Tensor([0]).to(device)
        loss_ce_total = torch.Tensor([0]).to(device)
        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
            if train_stage == 4:
                # in the final stage, also train the random input head
                count = 5 if stage_epoch <= 20 else 2
                for _ in range(count):
                    optimizer_head.zero_grad()
                    output = F.softmax(ginka_head(masked1), dim=1)
                    loss_head = criterion.generator_input_head_loss(output)
                    loss_head.backward()
                    optimizer_head.step()
            # ---------- train the critic (discriminator)
            for _ in range(c_steps):
                # generate fake samples
                optimizer_minamo.zero_grad()
                optimizer_ginka.zero_grad()
                optimizer_head.zero_grad()
                with torch.no_grad():
                    if train_stage == 1 or train_stage == 2:
                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
                    elif train_stage == 3 or train_stage == 4:
                        fake1, fake2, fake3, _ = gen_total(ginka, masked1, True, True, train_stage == 4)
                loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
                loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
                loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
                dis_avg = (dis1 + dis2 + dis3) / 3.0
                loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
                # backpropagate and update the critic only
                loss_d_avg.backward()
                optimizer_minamo.step()
                loss_total_minamo += loss_d_avg.detach()
                dis_total += dis_avg.detach()
            # ---------- train the generator
            for _ in range(g_steps):
                optimizer_minamo.zero_grad()
                optimizer_ginka.zero_grad()
                optimizer_head.zero_grad()
                if train_stage == 1 or train_stage == 2:
                    fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)
                    loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1)
                    loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2)
                    loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3)
                    loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
                    loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
                    loss_g.backward()
                    optimizer_ginka.step()
                    loss_total_ginka += loss_g.detach()
                    loss_ce_total += loss_ce.detach()
                elif train_stage == 3 or train_stage == 4:
                    # generate from masked1, mirroring the critic step above (an assumption;
                    # in stage 4 the random flag presumably switches the first stage to the
                    # random input head)
                    fake1, fake2, fake3, x_in = gen_total(ginka, masked1, True, False, train_stage == 4)
                    if train_stage == 3:
                        loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
                    else:
                        loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
                    loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
                    loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
                    if train_stage == 4:
                        loss_head = criterion.generator_input_head_loss(x_in)
                        # retain the graph: loss_g below shares the same forward pass as x_in
                        loss_head.backward(retain_graph=True)
                    loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
                    loss_g.backward()
                    optimizer_ginka.step()
                    loss_total_ginka += loss_g.detach()
        avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
        avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
        avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
        avg_dis = dis_total.item() / len(dataloader) / c_steps
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
            f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
            f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
        )
        if avg_loss_ce < 1.0:
            low_loss_epochs += 1
        else:
            low_loss_epochs = 0
        # training schedule control
        if low_loss_epochs >= 3 and train_stage == 1 and stage_epoch >= args.curr_epoch:
            if mask_ratio >= 0.9:
                train_stage = 2
            mask_ratio += 0.2
            mask_ratio = min(mask_ratio, 0.9)
            low_loss_epochs = 0
            stage_epoch = 0
        if (train_stage == 3 or train_stage == 2) and not last_stage:
            if stage_epoch >= 25:
                train_stage += 1
                stage_epoch = 0
                if train_stage == 4:
                    last_stage = True
        if train_stage >= 3 or last_stage:
            # from stage 3 onwards the cross-entropy loss should no longer take effect
            mask_ratio = 1.0
        if last_stage:
            mask_ratio = 1.0
        if train_stage == 2 and stage_epoch % 5 == 0:
            train_stage = 4
        if train_stage == 4 and stage_epoch % 5 == 1:
            train_stage = 2
        stage_epoch += 1
        dataset.train_stage = train_stage
        dataset_val.train_stage = train_stage
        dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
        dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
        # scheduler_ginka.step()
        # scheduler_minamo.step()
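        # Adaptively rebalance the update counts: more generator steps when the
        # critic's distance estimate goes negative, more critic steps (up to 15)
        # while the critic loss stays positive.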
        if avg_dis < 0:
            g_steps = max(int(-avg_dis * 5), 1)
        else:
            g_steps = 1
        if avg_loss_minamo > 0:
            c_steps = int(min(5 + avg_loss_minamo * 5, 15))
        else:
            c_steps = 5
        # every few epochs, render sample images and save checkpoints
        if (epoch + 1) % args.checkpoint == 0:
            # save checkpoints
            torch.save({
                "model_state": ginka.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
                "c_steps": c_steps,
                "g_steps": g_steps,
                "stage": train_stage,
                "mask_ratio": mask_ratio,
                "stage_epoch": stage_epoch,
                "last_stage": last_stage
            }, f"result/wgan/ginka-{epoch + 1}.pth")
            torch.save({
                "model_state": minamo.state_dict(),
                "optim_state": optimizer_minamo.state_dict()
            }, f"result/wgan/minamo-{epoch + 1}.pth")
            idx = 0
            with torch.no_grad():
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
                    if train_stage == 1 or train_stage == 2:
                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
                    elif train_stage == 3 or train_stage == 4:
                        input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
                        fake1, fake2, fake3, _ = gen_total(ginka, input, True, True)
                    fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
                    fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
                    fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
                    for i in range(fake1.shape[0]):
                        for key, one in enumerate([fake1, fake2, fake3]):
                            map_matrix = one[i]
                            image = matrix_to_image_cv(map_matrix, tile_dict)
                            cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
                        idx += 1
print("Train ended.")
torch.save({
"model_state": ginka.state_dict(),
}, f"result/ginka.pth")
torch.save({
"model_state": minamo.state_dict(),
}, f"result/minamo.pth")
if __name__ == "__main__":
torch.set_num_threads(4)
train()