mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
220 lines
9.4 KiB
Python
220 lines
9.4 KiB
Python
import argparse
|
|
import os
|
|
import sys
|
|
from datetime import datetime
|
|
import torch
|
|
import torch.optim as optim
|
|
import cv2
|
|
from torch_geometric.loader import DataLoader
|
|
from tqdm import tqdm
|
|
from .model.model import GinkaModel
|
|
from .dataset import GinkaWGANDataset
|
|
from .model.loss import WGANGinkaLoss
|
|
from minamo.model.model import MinamoScoreModule
|
|
from minamo.model.similarity import MinamoSimilarityModel
|
|
from shared.graph import batch_convert_soft_map_to_graph
|
|
from shared.image import matrix_to_image_cv
|
|
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
|
|
|
# Mini-batch size used by the training DataLoader.
BATCH_SIZE = 32

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint output directories (parent first, then the WGAN sub-directory).
for _output_dir in ("result", "result/wgan"):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

# Progress bars are suppressed whenever stdout is not an interactive terminal.
_stdout_is_tty = sys.stdout.isatty()
disable_tqdm = not _stdout_is_tty
del _stdout_is_tty
|
|
|
|
def _str2bool(value):
    """Convert a command-line string such as "true", "False" or "0" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--resume False`` would turn resuming ON. This
    converter accepts the usual spellings instead. The empty string stays False,
    as it was with ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling
            (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``resume``, ``state_ginka``, ``state_minamo``,
        ``train``, ``epochs``, ``checkpoint`` and ``load_optim``.
    """
    parser = argparse.ArgumentParser(description="training codes")
    # Boolean flags take an explicit value, e.g. "--resume true" / "--load_optim false".
    parser.add_argument("--resume", type=_str2bool, default=False)
    # Checkpoint files read when --resume is enabled.
    parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
    parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
    # Path handed to GinkaWGANDataset.
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--epochs", type=int, default=100)
    # Save sample images and checkpoints every N epochs.
    parser.add_argument("--checkpoint", type=int, default=5)
    # When resuming, also restore optimizer state from the checkpoints.
    parser.add_argument("--load_optim", type=_str2bool, default=True)
    return parser.parse_args(argv)
|
|
|
|
def clip_weights(model, clip_value=0.01):
    """Clamp every parameter of ``model`` in place into ``[-clip_value, clip_value]``.

    WGAN-style weight clipping. Note that it is not called from ``train()`` in
    this file.

    Args:
        model: Any module exposing ``parameters()``.
        clip_value: Half-width of the allowed interval.
    """
    bound = clip_value
    for tensor in model.parameters():
        # Rebind ``.data`` so the clamp bypasses autograd tracking.
        tensor.data = tensor.data.clamp(min=-bound, max=bound)
|
|
|
|
def _load_tile_images(tile_dir="tiles"):
    """Read every file in ``tile_dir`` with OpenCV (``cv2.IMREAD_UNCHANGED``).

    Returns a dict mapping the file name without extension to the decoded image
    (``cv2.imread`` yields ``None`` for files it cannot decode). The dict is only
    used to render generated maps with ``matrix_to_image_cv``.
    """
    tiles = dict()
    for file in os.listdir(tile_dir):
        name = os.path.splitext(file)[0]
        tiles[name] = cv2.imread(f"{tile_dir}/{file}", cv2.IMREAD_UNCHANGED)
    return tiles


def _restore_from_checkpoints(args, ginka, minamo, minamo_sim,
                              optimizer_ginka, optimizer_minamo, optimizer_minamo_sim,
                              c_steps, g_steps):
    """Load model (and optionally optimizer) state from ``args.state_ginka`` / ``args.state_minamo``.

    Optional keys are looked up with ``.get`` so that checkpoints lacking them
    can still be resumed. This includes the final ``result/ginka.pth``, which
    stores only ``model_state``.

    Returns:
        The ``(c_steps, g_steps)`` to continue with: the saved values when both
        are present, otherwise the values passed in.
    """
    data_ginka = torch.load(args.state_ginka, map_location=device)
    data_minamo = torch.load(args.state_minamo, map_location=device)

    ginka.load_state_dict(data_ginka["model_state"], strict=False)
    minamo.load_state_dict(data_minamo["model_state"], strict=False)
    # The similarity model is checkpointed under "model_state_sim"; restore it too
    # so the generator loss sees the same weights after resuming.
    if data_minamo.get("model_state_sim") is not None:
        minamo_sim.load_state_dict(data_minamo["model_state_sim"], strict=False)

    saved_c, saved_g = data_ginka.get("c_steps"), data_ginka.get("g_steps")
    if saved_c is not None and saved_g is not None:
        # int(): older checkpoints may hold a float c_steps, which range() rejects.
        c_steps = max(1, int(saved_c))
        g_steps = max(1, int(saved_g))

    if args.load_optim:
        for optimizer, state in (
            (optimizer_ginka, data_ginka.get("optim_state")),
            (optimizer_minamo, data_minamo.get("optim_state")),
            (optimizer_minamo_sim, data_minamo.get("optim_state_sim")),
        ):
            if state is not None:
                optimizer.load_state_dict(state)

    print("Train from loaded state.")
    return c_steps, g_steps


def _adapt_step_counts(avg_dis, avg_loss_ginka, avg_loss_minamo, base_c_steps=5, max_c_steps=15):
    """Choose the next epoch's ``(c_steps, g_steps)`` from this epoch's averages.

    * ``g_steps`` is ``max(int(-avg_dis * 5), 1)`` when ``avg_dis`` is negative,
      otherwise 1.
    * ``c_steps`` is ``base_c_steps + 5 * (avg_loss_ginka + avg_loss_minamo)``
      clamped to ``[1, max_c_steps]`` when either loss is positive, otherwise
      ``base_c_steps``.

    Both results are always positive ints. ``range()`` rejects floats, and a
    zero count would make the next epoch's averages divide by zero. NaN or
    infinite inputs fall back to the defaults instead of crashing.
    """
    import math

    if math.isfinite(avg_dis) and avg_dis < 0:
        g_steps = max(int(-avg_dis * 5), 1)
    else:
        g_steps = 1

    if avg_loss_ginka > 0 or avg_loss_minamo > 0:
        raw = base_c_steps + (avg_loss_ginka + avg_loss_minamo) * 5
        if math.isnan(raw):
            c_steps = base_c_steps
        else:
            c_steps = int(min(max(raw, 1), max_c_steps))
    else:
        c_steps = base_c_steps
    return c_steps, g_steps


def _save_sample_images(ginka, tile_dict, out_dir="result/ginka_img", batches=5, per_batch=4):
    """Render ``batches * per_batch`` generated maps to ``out_dir/<idx>.png``.

    Each map is decoded with ``argmax`` over dim 1 of the generator output
    (presumably the tile-class axis; TODO confirm) and then drawn with
    ``matrix_to_image_cv``.
    """
    # The output directory is not created anywhere else; without it cv2.imwrite
    # silently returns False.
    os.makedirs(out_dir, exist_ok=True)
    idx = 0
    with torch.no_grad():
        for _ in range(batches):
            z = torch.randn(per_batch, 1024, device=device)
            output = ginka(z)
            map_matrix = torch.argmax(output, dim=1).cpu().numpy()
            for matrix in map_matrix:
                image = matrix_to_image_cv(matrix, tile_dict)
                path = f"{out_dir}/{idx}.png"
                if not cv2.imwrite(path, image):
                    tqdm.write(f"Failed to write sample image {path}")
                idx += 1


def train():
    """Train the Ginka generator against the Minamo critic with a WGAN-style loop.

    Per batch, the function runs ``c_steps`` critic updates
    (``WGANGinkaLoss.discriminator_loss``) and then ``g_steps`` generator updates
    (``WGANGinkaLoss.generator_loss``). Both counts are re-chosen after every
    epoch by ``_adapt_step_counts``.

    Every ``--checkpoint`` epochs it writes 20 sample PNGs to
    ``result/ginka_img`` and checkpoints to ``result/wgan``. At the end it saves
    ``result/ginka.pth`` and ``result/minamo.pth``. Latent vectors are
    1024-dimensional standard-normal noise.

    Raises:
        ValueError: if the dataset yields no full batch of ``BATCH_SIZE`` samples.
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments()

    # Critic / generator updates per batch; adapted after every epoch.
    c_steps = 5
    g_steps = 1

    ginka = GinkaModel()
    minamo = MinamoScoreModule()
    minamo_sim = MinamoSimilarityModel()
    ginka.to(device)
    minamo.to(device)
    minamo_sim.to(device)

    dataset = GinkaWGANDataset(args.train, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError(
            f"Dataset '{args.train}' has fewer than {BATCH_SIZE} samples; "
            "with drop_last=True no training batch would be produced."
        )

    optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
    optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
    # minamo_sim is not trained in this loop; its optimizer is only (re)stored in
    # checkpoints.
    optimizer_minamo_sim = optim.Adam(minamo_sim.parameters(), lr=1e-4)

    criterion = WGANGinkaLoss()

    # Tile images used to render generated maps.
    tile_dict = _load_tile_images()

    if args.resume:
        c_steps, g_steps = _restore_from_checkpoints(
            args, ginka, minamo, minamo_sim,
            optimizer_ginka, optimizer_minamo, optimizer_minamo_sim,
            c_steps, g_steps,
        )

    for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
        loss_total_minamo = torch.zeros(1, device=device)
        loss_total_ginka = torch.zeros(1, device=device)
        dis_total = torch.zeros(1, device=device)

        for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            batch_size = real_data.size(0)
            real_data = real_data.to(device)
            real_graph = batch_convert_soft_map_to_graph(real_data)

            # ---------- Train the critic (Minamo)
            for _ in range(c_steps):
                optimizer_minamo.zero_grad()
                # Standard-normal latent, the same distribution used for the
                # generator updates below and for sample rendering.
                z = torch.randn(batch_size, 1024, device=device)
                # Detach so critic updates do not backpropagate into the generator.
                fake_data = ginka(z).detach()

                dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
                loss_d.backward()
                optimizer_minamo.step()

                loss_total_minamo += loss_d.detach()
                dis_total += dis.detach()

            # ---------- Train the generator (Ginka)
            for _ in range(g_steps):
                optimizer_ginka.zero_grad()

                z1 = torch.randn(batch_size, 1024, device=device)
                z2 = torch.randn(batch_size, 1024, device=device)
                fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)

                loss_g = criterion.generator_loss(minamo, minamo_sim, fake_softmax1, fake_softmax2)
                loss_g.backward()
                optimizer_ginka.step()

                # Detach: accumulating the live loss would keep every step's
                # autograd graph in memory until the end of the epoch.
                loss_total_ginka += loss_g.detach()

        avg_loss_ginka = loss_total_ginka.item() / num_batches / g_steps
        avg_loss_minamo = loss_total_minamo.item() / num_batches / c_steps
        avg_dis = dis_total.item() / num_batches / c_steps
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"Epoch: {epoch + 1} | W Loss: {avg_dis:.8f} | " +
            f"G Loss: {avg_loss_ginka:.8f} | D Loss: {avg_loss_minamo:.8f} | " +
            f"lr G: {(optimizer_ginka.param_groups[0]['lr']):.8f}"
        )

        c_steps, g_steps = _adapt_step_counts(avg_dis, avg_loss_ginka, avg_loss_minamo)

        # Every `--checkpoint` epochs, dump sample images and save checkpoints.
        if args.checkpoint > 0 and (epoch + 1) % args.checkpoint == 0:
            # 20 images: 5 batches of 4.
            _save_sample_images(ginka, tile_dict)

            torch.save({
                "model_state": ginka.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
                "c_steps": c_steps,
                "g_steps": g_steps
            }, f"result/wgan/ginka-{epoch + 1}.pth")
            torch.save({
                "model_state": minamo.state_dict(),
                "model_state_sim": minamo_sim.state_dict(),
                "optim_state": optimizer_minamo.state_dict(),
                "optim_state_sim": optimizer_minamo_sim.state_dict()
            }, f"result/wgan/minamo-{epoch + 1}.pth")

    print("Train ended.")
    torch.save({
        "model_state": ginka.state_dict(),
    }, "result/ginka.pth")
    torch.save({
        "model_state": minamo.state_dict(),
        "model_state_sim": minamo_sim.state_dict(),
    }, "result/minamo.pth")
|
|
|
|
def _entry_point():
    """Run training when this module is executed as a script."""
    # Cap PyTorch's CPU intra-op thread pool before any work starts.
    torch.set_num_threads(4)
    train()


if __name__ == "__main__":
    _entry_point()
|