mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
349 lines
15 KiB
Python
349 lines
15 KiB
Python
import argparse
|
|
import os
|
|
import sys
|
|
from datetime import datetime
|
|
import torch
|
|
import torch.optim as optim
|
|
import torch.nn.functional as F
|
|
import cv2
|
|
from torch_geometric.loader import DataLoader
|
|
from tqdm import tqdm
|
|
from .model.model import GinkaModel
|
|
from .dataset import GinkaWGANDataset
|
|
from .model.loss import WGANGinkaLoss
|
|
from .model.input import RandomInputHead
|
|
from minamo.model.model import MinamoScoreModule
|
|
from shared.image import matrix_to_image_cv
|
|
|
|
BATCH_SIZE = 16
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
os.makedirs("result", exist_ok=True)
|
|
os.makedirs("result/wgan", exist_ok=True)
|
|
|
|
disable_tqdm = not sys.stdout.isatty()
|
|
|
|
def parse_arguments():
|
|
parser = argparse.ArgumentParser(description="training codes")
|
|
parser.add_argument("--resume", type=bool, default=False)
|
|
parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
|
|
parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
|
|
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
|
parser.add_argument("--validate", type=str, default="ginka-eval.json")
|
|
parser.add_argument("--epochs", type=int, default=100)
|
|
parser.add_argument("--checkpoint", type=int, default=5)
|
|
parser.add_argument("--load_optim", type=bool, default=True)
|
|
parser.add_argument("--curr_epoch", type=int, default=20) # 课程学习至少多少 epoch
|
|
args = parser.parse_args()
|
|
return args
|
|
|
|
def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
fake1, _ = gen(masked1, 1)
|
|
fake2, _ = gen(masked2, 2)
|
|
fake3, _ = gen(masked3, 3)
|
|
if detach:
|
|
return fake1.detach(), fake2.detach(), fake3.detach()
|
|
else:
|
|
return fake1, fake2, fake3
|
|
|
|
def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
|
|
if progress_detach:
|
|
fake1, x_in = gen(input.detach(), 1, random)
|
|
fake2, _ = gen(F.softmax(fake1.detach()), 2)
|
|
fake3, _ = gen(F.softmax(fake2.detach()), 3)
|
|
else:
|
|
fake1, x_in = gen(input, 1, random)
|
|
fake2, _ = gen(F.softmax(fake1), 2)
|
|
fake3, _ = gen(F.softmax(fake2), 3)
|
|
if result_detach:
|
|
return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
|
|
else:
|
|
return fake1, fake2, fake3, x_in
|
|
|
|
def train():
|
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
|
|
|
args = parse_arguments()
|
|
|
|
c_steps = 5
|
|
g_steps = 1
|
|
# 训练阶段
|
|
train_stage = 1
|
|
last_stage = False
|
|
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
|
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
|
|
|
|
ginka = GinkaModel().to(device)
|
|
ginka_head = RandomInputHead().to(device)
|
|
minamo = MinamoScoreModule().to(device)
|
|
|
|
dataset = GinkaWGANDataset(args.train, device)
|
|
dataset_val = GinkaWGANDataset(args.validate, device)
|
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
|
|
|
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
|
optimizer_head = optim.Adam(ginka_head.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
|
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
|
|
|
|
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
|
|
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
|
|
|
|
criterion = WGANGinkaLoss()
|
|
|
|
# 用于生成图片
|
|
tile_dict = dict()
|
|
for file in os.listdir('tiles'):
|
|
name = os.path.splitext(file)[0]
|
|
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
|
|
|
if args.resume:
|
|
data_ginka = torch.load(args.state_ginka, map_location=device)
|
|
data_minamo = torch.load(args.state_minamo, map_location=device)
|
|
|
|
ginka.load_state_dict(data_ginka["model_state"], strict=False)
|
|
minamo.load_state_dict(data_minamo["model_state"], strict=False)
|
|
|
|
if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None:
|
|
c_steps = data_ginka["c_steps"]
|
|
g_steps = data_ginka["g_steps"]
|
|
|
|
if data_ginka.get("mask_ratio") is not None:
|
|
mask_ratio = data_ginka["mask_ratio"]
|
|
|
|
if data_ginka.get("stage_epoch") is not None:
|
|
stage_epoch = data_ginka["stage_epoch"]
|
|
|
|
if data_ginka.get("stage") is not None:
|
|
train_stage = data_ginka["stage"]
|
|
|
|
if data_ginka.get("last_stage") is not None:
|
|
last_stage = data_ginka["last_stage"]
|
|
|
|
if args.load_optim:
|
|
if data_ginka.get("optim_state") is not None:
|
|
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
|
|
if data_minamo.get("optim_state") is not None:
|
|
optimizer_minamo.load_state_dict(data_minamo["optim_state"])
|
|
|
|
dataset.train_stage = train_stage
|
|
dataset.mask_ratio1 = mask_ratio
|
|
dataset.mask_ratio2 = mask_ratio
|
|
dataset.mask_ratio3 = mask_ratio
|
|
|
|
dataset_val.train_stage = train_stage
|
|
dataset_val.mask_ratio1 = mask_ratio
|
|
dataset_val.mask_ratio2 = mask_ratio
|
|
dataset_val.mask_ratio3 = mask_ratio
|
|
|
|
print("Train from loaded state.")
|
|
|
|
low_loss_epochs = 0
|
|
|
|
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
|
|
loss_total_minamo = torch.Tensor([0]).to(device)
|
|
loss_total_ginka = torch.Tensor([0]).to(device)
|
|
dis_total = torch.Tensor([0]).to(device)
|
|
loss_ce_total = torch.Tensor([0]).to(device)
|
|
|
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
|
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
|
|
|
|
if train_stage == 4:
|
|
# 最后一个阶段训练输入头
|
|
count = 5 if stage_epoch <= 20 else 2
|
|
for _ in range(count):
|
|
optimizer_head.zero_grad()
|
|
output = F.softmax(ginka_head(masked1), dim=1)
|
|
loss_head = criterion.generator_input_head_loss(output)
|
|
loss_head.backward()
|
|
optimizer_head.step()
|
|
|
|
# ---------- 训练判别器
|
|
for _ in range(c_steps):
|
|
# 生成假样本
|
|
optimizer_minamo.zero_grad()
|
|
optimizer_ginka.zero_grad()
|
|
optimizer_head.zero_grad()
|
|
|
|
with torch.no_grad():
|
|
if train_stage == 1 or train_stage == 2:
|
|
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
|
|
|
elif train_stage == 3 or train_stage == 4:
|
|
fake1, fake2, fake3, _ = gen_total(ginka, masked1, True, True, train_stage == 4)
|
|
|
|
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
|
|
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
|
|
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
|
|
|
|
dis_avg = (dis1 + dis2 + dis3) / 3.0
|
|
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
|
|
|
|
# 反向传播
|
|
loss_d_avg.backward()
|
|
|
|
optimizer_minamo.step()
|
|
|
|
loss_total_minamo += loss_d_avg.detach()
|
|
dis_total += dis_avg.detach()
|
|
|
|
# ---------- 训练生成器
|
|
|
|
for _ in range(g_steps):
|
|
optimizer_minamo.zero_grad()
|
|
optimizer_ginka.zero_grad()
|
|
optimizer_head.zero_grad()
|
|
if train_stage == 1 or train_stage == 2:
|
|
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)
|
|
|
|
loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1)
|
|
loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2)
|
|
loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3)
|
|
|
|
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
|
loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
|
|
|
|
loss_g.backward()
|
|
optimizer_ginka.step()
|
|
loss_total_ginka += loss_g.detach()
|
|
loss_ce_total += loss_ce.detach()
|
|
|
|
elif train_stage == 3 or train_stage == 4:
|
|
fake1, fake2, fake3, x_in = gen_total(ginka, input, True, False)
|
|
|
|
if train_stage == 3:
|
|
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, input)
|
|
else:
|
|
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
|
|
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
|
|
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
|
|
|
|
if train_stage == 4:
|
|
loss_head = criterion.generator_input_head_loss(x_in)
|
|
loss_head.backward()
|
|
|
|
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
|
loss_g.backward()
|
|
optimizer_ginka.step()
|
|
loss_total_ginka += loss_g.detach()
|
|
|
|
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
|
|
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
|
|
avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
|
|
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
|
tqdm.write(
|
|
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
|
f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
|
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
|
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
|
|
)
|
|
|
|
if avg_loss_ce < 1.0:
|
|
low_loss_epochs += 1
|
|
else:
|
|
low_loss_epochs = 0
|
|
|
|
# 训练流程控制
|
|
|
|
if low_loss_epochs >= 3 and train_stage == 1 and stage_epoch >= args.curr_epoch:
|
|
if mask_ratio >= 0.9:
|
|
train_stage = 2
|
|
mask_ratio += 0.2
|
|
mask_ratio = min(mask_ratio, 0.9)
|
|
low_loss_epochs = 0
|
|
stage_epoch = 0
|
|
|
|
if (train_stage == 3 or train_stage == 2) and not last_stage:
|
|
if stage_epoch >= 25:
|
|
train_stage += 1
|
|
stage_epoch = 0
|
|
|
|
if train_stage == 4:
|
|
last_stage = True
|
|
|
|
if train_stage >= 3 or last_stage:
|
|
# 第三阶段后交叉熵损失不再应该生效
|
|
mask_ratio = 1.0
|
|
|
|
if last_stage:
|
|
mask_ratio = 1.0
|
|
if train_stage == 2 and stage_epoch % 5 == 0:
|
|
train_stage = 4
|
|
|
|
if train_stage == 4 and stage_epoch % 5 == 1:
|
|
train_stage = 2
|
|
|
|
stage_epoch += 1
|
|
|
|
dataset.train_stage = train_stage
|
|
dataset_val.train_stage = train_stage
|
|
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
|
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
|
|
|
|
# scheduler_ginka.step()
|
|
# scheduler_minamo.step()
|
|
|
|
if avg_dis < 0:
|
|
g_steps = max(int(-avg_dis * 5), 1)
|
|
else:
|
|
g_steps = 1
|
|
|
|
if avg_loss_minamo > 0:
|
|
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
|
else:
|
|
c_steps = 5
|
|
|
|
# 每若干轮输出一次图片,并保存检查点
|
|
if (epoch + 1) % args.checkpoint == 0:
|
|
# 保存检查点
|
|
torch.save({
|
|
"model_state": ginka.state_dict(),
|
|
"optim_state": optimizer_ginka.state_dict(),
|
|
"c_steps": c_steps,
|
|
"g_steps": g_steps,
|
|
"stage": train_stage,
|
|
"mask_ratio": mask_ratio,
|
|
"stage_epoch": stage_epoch,
|
|
"last_stage": last_stage
|
|
}, f"result/wgan/ginka-{epoch + 1}.pth")
|
|
torch.save({
|
|
"model_state": minamo.state_dict(),
|
|
"optim_state": optimizer_minamo.state_dict()
|
|
}, f"result/wgan/minamo-{epoch + 1}.pth")
|
|
|
|
idx = 0
|
|
with torch.no_grad():
|
|
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
|
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
|
|
if train_stage == 1 or train_stage == 2:
|
|
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
|
|
|
elif train_stage == 3 or train_stage == 4:
|
|
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
|
|
fake1, fake2, fake3, _ = gen_total(ginka, input, True, True)
|
|
|
|
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
|
|
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
|
|
fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
|
|
|
|
for i in range(fake1.shape[0]):
|
|
for key, one in enumerate([fake1, fake2, fake3]):
|
|
map_matrix = one[i]
|
|
image = matrix_to_image_cv(map_matrix, tile_dict)
|
|
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
|
|
|
|
idx += 1
|
|
|
|
print("Train ended.")
|
|
torch.save({
|
|
"model_state": ginka.state_dict(),
|
|
}, f"result/ginka.pth")
|
|
torch.save({
|
|
"model_state": minamo.state_dict(),
|
|
}, f"result/minamo.pth")
|
|
|
|
if __name__ == "__main__":
|
|
torch.set_num_threads(4)
|
|
train()
|