ginka-generator/ginka/train_wgan.py
2025-04-10 22:42:58 +08:00

220 lines
9.4 KiB
Python

import argparse
import os
import sys
from datetime import datetime
import torch
import torch.optim as optim
import cv2
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import GinkaModel
from .dataset import GinkaWGANDataset
from .model.loss import WGANGinkaLoss
from minamo.model.model import MinamoScoreModule
from minamo.model.similarity import MinamoSimilarityModel
from shared.graph import batch_convert_soft_map_to_graph
from shared.image import matrix_to_image_cv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
# Number of samples drawn per training batch.
BATCH_SIZE = 32

# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output folders must exist before checkpoints are written into them.
for _output_dir in ("result", "result/wgan"):
    os.makedirs(_output_dir, exist_ok=True)

# Progress bars are switched off when stdout is not an interactive terminal
# (for example when the output is redirected to a log file).
disable_tqdm = not sys.stdout.isatty()
def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"``) becomes ``True``. This converter
    interprets the usual textual spellings instead.

    Args:
        value: The raw token from the command line, or an actual ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string is kept falsy, matching what bool("") used to give.
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse the command-line options of the WGAN training script.

    Options:
        --resume: Whether to continue from saved checkpoints (default False).
            May be given bare (``--resume``) or with a value (``--resume True``).
        --state_ginka: Path of the generator checkpoint used when resuming.
        --state_minamo: Path of the critic checkpoint used when resuming.
        --train: Path of the training dataset file.
        --epochs: Number of training epochs.
        --checkpoint: Save a checkpoint and sample images every N epochs.
        --load_optim: Whether optimizer states are restored when resuming
            (default True).

    Returns:
        argparse.Namespace: The parsed arguments, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="training codes")
    # nargs="?" + const=True lets both "--resume" and "--resume <bool>" work.
    parser.add_argument("--resume", type=_str_to_bool, nargs="?", const=True, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
    parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=_str_to_bool, nargs="?", const=True, default=True)
    return parser.parse_args()
def clip_weights(model, clip_value=0.01):
    """Clamp every parameter of ``model`` into ``[-clip_value, clip_value]``.

    Args:
        model: Any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        clip_value: Symmetric bound applied to each parameter element.

    Returns:
        None. The clamped values are assigned back to each ``param.data``.
    """
    lower, upper = -clip_value, clip_value
    for weight in model.parameters():
        weight.data = weight.data.clamp(min=lower, max=upper)
# Size of the latent noise vector fed to the generator.
_LATENT_DIM = 1024


def _load_tile_images(tile_dir="tiles"):
    """Read every file in ``tile_dir`` with OpenCV, keyed by its name without extension."""
    tiles = dict()
    for file in os.listdir(tile_dir):
        name = os.path.splitext(file)[0]
        tiles[name] = cv2.imread(f"{tile_dir}/{file}", cv2.IMREAD_UNCHANGED)
    return tiles


def _restore_state(
    args, ginka, minamo, minamo_sim,
    optimizer_ginka, optimizer_minamo, optimizer_minamo_sim,
    c_steps, g_steps
):
    """Load model/optimizer states from the checkpoints named in ``args``.

    Optional checkpoint entries are read with ``dict.get`` so that files
    lacking them do not raise ``KeyError``.

    Returns:
        tuple[int, int]: The ``(c_steps, g_steps)`` to train with; the saved
        values when the generator checkpoint has both, else the ones passed in.
    """
    data_ginka = torch.load(args.state_ginka, map_location=device)
    data_minamo = torch.load(args.state_minamo, map_location=device)

    ginka.load_state_dict(data_ginka["model_state"], strict=False)
    minamo.load_state_dict(data_minamo["model_state"], strict=False)

    # The similarity model is written to every critic checkpoint, so restore
    # it as well instead of resuming with freshly initialised weights.
    sim_state = data_minamo.get("model_state_sim")
    if sim_state is not None:
        minamo_sim.load_state_dict(sim_state, strict=False)

    saved_c = data_ginka.get("c_steps")
    saved_g = data_ginka.get("g_steps")
    if saved_c is not None and saved_g is not None:
        # Older checkpoints may hold a float c_steps; range() needs an int >= 1.
        c_steps = max(int(saved_c), 1)
        g_steps = max(int(saved_g), 1)

    if args.load_optim:
        if data_ginka.get("optim_state") is not None:
            optimizer_ginka.load_state_dict(data_ginka["optim_state"])
        if data_minamo.get("optim_state") is not None:
            optimizer_minamo.load_state_dict(data_minamo["optim_state"])
        if data_minamo.get("optim_state_sim") is not None:
            optimizer_minamo_sim.load_state_dict(data_minamo["optim_state_sim"])

    return c_steps, g_steps


def _next_step_counts(avg_dis, avg_loss_ginka, avg_loss_minamo):
    """Derive the critic/generator step counts for the next epoch.

    Args:
        avg_dis: Epoch average of the first value returned by the critic loss.
        avg_loss_ginka: Epoch average of the generator loss.
        avg_loss_minamo: Epoch average of the critic loss.

    Returns:
        tuple[int, int]: ``(c_steps, g_steps)``; both are integers >= 1 so they
        can be passed to ``range``. ``c_steps`` never exceeds 15.
    """
    if avg_dis < 0:
        g_steps = max(int(-avg_dis * 5), 1)
    else:
        g_steps = 1

    if avg_loss_ginka > 0 or avg_loss_minamo > 0:
        raw_steps = 5 + (avg_loss_ginka + avg_loss_minamo) * 5
        # The raw value is a float and may be negative when one loss is a
        # large negative number; keep it an integer within [1, 15].
        c_steps = int(min(max(raw_steps, 1), 15))
    else:
        c_steps = 5

    return c_steps, g_steps


def _save_sample_images(ginka, tile_dict, out_dir="result/ginka_img"):
    """Render 20 generated maps (five batches of four) as PNG files in ``out_dir``."""
    # cv2.imwrite does not create missing folders; it just fails to write.
    os.makedirs(out_dir, exist_ok=True)
    idx = 0
    with torch.no_grad():
        for _ in range(5):
            z = torch.randn(4, _LATENT_DIM, device=device)
            output = ginka(z)
            # Pick the highest-scoring entry along dim 1 for every position
            # (presumably the tile class channel -- confirm against GinkaModel).
            map_matrix = torch.argmax(output, dim=1).cpu().numpy()
            for matrix in map_matrix:
                image = matrix_to_image_cv(matrix, tile_dict)
                cv2.imwrite(f"{out_dir}/{idx}.png", image)
                idx += 1


def train():
    """Run WGAN training of the Ginka generator against the Minamo critic.

    Each batch runs ``c_steps`` critic updates followed by ``g_steps``
    generator updates; both counts are re-derived from the epoch averages
    after every epoch. Every ``--checkpoint`` epochs, sample images are written
    to ``result/ginka_img`` and checkpoints to ``result/wgan``. The final
    weights are saved to ``result/ginka.pth`` and ``result/minamo.pth``.

    Raises:
        ValueError: If the dataloader yields no batch (``drop_last=True`` drops
            a dataset smaller than ``BATCH_SIZE`` entirely).
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments()

    c_steps = 5
    g_steps = 1

    ginka = GinkaModel()
    minamo = MinamoScoreModule()
    minamo_sim = MinamoSimilarityModel()
    ginka.to(device)
    minamo.to(device)
    minamo_sim.to(device)

    dataset = GinkaWGANDataset(args.train, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError(
            f"Training set yields no full batch of {BATCH_SIZE} samples; nothing to train on."
        )

    optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
    optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
    # NOTE: this optimizer is only saved/restored here; it is never stepped in
    # this loop, although minamo_sim is passed to the generator loss.
    optimizer_minamo_sim = optim.Adam(minamo_sim.parameters(), lr=1e-4)

    criterion = WGANGinkaLoss()

    # Tile images used to render generated maps.
    tile_dict = _load_tile_images("tiles")

    if args.resume:
        c_steps, g_steps = _restore_state(
            args, ginka, minamo, minamo_sim,
            optimizer_ginka, optimizer_minamo, optimizer_minamo_sim,
            c_steps, g_steps
        )
        print("Train from loaded state.")

    for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
        loss_total_minamo = torch.zeros(1, device=device)
        loss_total_ginka = torch.zeros(1, device=device)
        dis_total = torch.zeros(1, device=device)

        for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            batch_size = real_data.size(0)
            real_data = real_data.to(device)
            real_graph = batch_convert_soft_map_to_graph(real_data)

            # ---------- Train the discriminator (critic)
            for _ in range(c_steps):
                optimizer_minamo.zero_grad()
                # Generate fake samples. The noise is Gaussian, the same
                # distribution used for the generator step and for sampling.
                z = torch.randn(batch_size, _LATENT_DIM, device=device)
                with torch.no_grad():
                    # No generator gradients are needed for the critic update.
                    fake_data = ginka(z)

                # Compute the critic loss and back-propagate.
                dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
                loss_d.backward()
                optimizer_minamo.step()

                loss_total_minamo += loss_d.detach()
                dis_total += dis.detach()

            # ---------- Train the generator
            for _ in range(g_steps):
                optimizer_ginka.zero_grad()
                z1 = torch.randn(batch_size, _LATENT_DIM, device=device)
                z2 = torch.randn(batch_size, _LATENT_DIM, device=device)
                fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)

                loss_g = criterion.generator_loss(minamo, minamo_sim, fake_softmax1, fake_softmax2)
                loss_g.backward()
                optimizer_ginka.step()

                # Detach so the running total does not keep every batch's
                # autograd graph alive until the end of the epoch.
                loss_total_ginka += loss_g.detach()

        avg_loss_ginka = loss_total_ginka.item() / num_batches / g_steps
        avg_loss_minamo = loss_total_minamo.item() / num_batches / c_steps
        avg_dis = dis_total.item() / num_batches / c_steps
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"Epoch: {epoch + 1} | W Loss: {avg_dis:.8f} | "
            f"G Loss: {avg_loss_ginka:.8f} | D Loss: {avg_loss_minamo:.8f} | "
            f"lr G: {(optimizer_ginka.param_groups[0]['lr']):.8f}"
        )

        # Re-balance critic/generator updates for the next epoch.
        c_steps, g_steps = _next_step_counts(avg_dis, avg_loss_ginka, avg_loss_minamo)

        # Every few epochs, output sample images and save a checkpoint.
        if (epoch + 1) % args.checkpoint == 0:
            # 20 images in total: five batches of four.
            _save_sample_images(ginka, tile_dict)

            # Save the checkpoints.
            torch.save({
                "model_state": ginka.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
                "c_steps": c_steps,
                "g_steps": g_steps
            }, f"result/wgan/ginka-{epoch + 1}.pth")
            torch.save({
                "model_state": minamo.state_dict(),
                "model_state_sim": minamo_sim.state_dict(),
                "optim_state": optimizer_minamo.state_dict(),
                "optim_state_sim": optimizer_minamo_sim.state_dict()
            }, f"result/wgan/minamo-{epoch + 1}.pth")

    print("Train ended.")
    torch.save({
        "model_state": ginka.state_dict(),
    }, "result/ginka.pth")
    torch.save({
        "model_state": minamo.state_dict(),
        "model_state_sim": minamo_sim.state_dict(),
    }, "result/minamo.pth")
if __name__ == "__main__":
    # Limit PyTorch to four CPU threads before training starts.
    torch.set_num_threads(4)
    train()