# Source: ginka-generator/ginka/train_wgan.py (Python, 165 lines, 6.3 KiB)

import argparse
import os
from datetime import datetime
import torch
import torch.optim as optim
import cv2
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import GinkaModel
from .dataset import GinkaWGANDataset
from .model.loss import WGANGinkaLoss
from minamo.model.model import MinamoScoreModule
from shared.graph import batch_convert_soft_map_to_graph
from shared.image import matrix_to_image_cv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
# Number of samples per training batch handed to the DataLoader.
BATCH_SIZE = 32

# Train on CUDA when PyTorch reports it as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create every output directory the script writes into up front.
# "result/ginka_img" receives the sample images written during training;
# cv2.imwrite does not create missing directories, so without it the
# images would be dropped without any error.
for _out_dir in ("result", "result/wgan", "result/ginka_img"):
    os.makedirs(_out_dir, exist_ok=True)
def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for
    every non-empty string (including ``"False"``). This converter interprets
    the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays False, matching what bool("") used to produce.
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with the attributes ``resume`` (bool),
        ``state_ginka`` (str), ``state_minamo`` (str), ``train`` (str)
        and ``epochs`` (int).
    """
    parser = argparse.ArgumentParser(description="training codes")
    # nargs="?" + const=True accepts both a bare "--resume" and the explicit
    # "--resume True" / "--resume False" forms.
    parser.add_argument(
        "--resume", type=_str_to_bool, nargs="?", const=True, default=False
    )
    parser.add_argument("--state_ginka", type=str, default="result/ginka.pth")
    parser.add_argument("--state_minamo", type=str, default="result/minamo.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--epochs", type=int, default=100)
    return parser.parse_args(argv)
def clip_weights(model, clip_value=0.01):
    """Clamp every parameter of *model* into [-clip_value, clip_value].

    Each parameter's ``.data`` tensor is replaced by a clamped copy, so the
    clipping itself is not recorded by autograd.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        clip_value: Symmetric bound applied to every parameter value.
    """
    lower, upper = -clip_value, clip_value
    for weight in model.parameters():
        weight.data = weight.data.clamp(min=lower, max=upper)
def _load_tiles(tile_dir="tiles"):
    """Read every file in *tile_dir* with OpenCV, keyed by its name without extension."""
    tiles = {}
    for file in os.listdir(tile_dir):
        name = os.path.splitext(file)[0]
        tiles[name] = cv2.imread(f"{tile_dir}/{file}", cv2.IMREAD_UNCHANGED)
    return tiles


def _select_generator_steps(avg_dis):
    """Return the generator steps per batch for the next epoch from the averaged ``dis`` value."""
    if avg_dis < -9:
        return 7
    if avg_dis < -6:
        return 5
    if avg_dis < -3:
        return 3
    return 1


def _save_state(model, path):
    """Save ``model.state_dict()`` to *path* under the ``"model_state"`` key."""
    torch.save({"model_state": model.state_dict()}, path)


def _save_samples(ginka, tile_dict, latent_dim, out_dir="result/ginka_img"):
    """Generate 20 samples (5 batches of 4) from *ginka* and write them as PNG files."""
    idx = 0
    with torch.no_grad():
        for _ in range(5):
            z = torch.randn(4, latent_dim, device=device)
            output = ginka(z)
            # argmax over dim 1 -- presumably the per-cell tile-class axis; TODO confirm.
            map_matrix = torch.argmax(output, dim=1).cpu().numpy()
            for matrix in map_matrix:
                image = matrix_to_image_cv(matrix, tile_dict)
                path = f"{out_dir}/{idx}.png"
                # cv2.imwrite reports failure through its return value
                # instead of raising, so surface it rather than dropping it.
                if not cv2.imwrite(path, image):
                    tqdm.write(f"Warning: failed to write sample image {path}")
                idx += 1


def train():
    """Run WGAN-style training of the Ginka generator against the Minamo critic.

    Reads its configuration from the command line via ``parse_arguments()``.
    Per batch, the critic (``minamo``) is updated ``c_steps`` times and the
    generator (``ginka``) ``g_steps`` times; ``g_steps`` is re-selected after
    every epoch from the epoch-averaged ``dis`` value. Every fifth epoch,
    sample images are written to ``result/ginka_img`` and both models are
    checkpointed to ``result/wgan``. The final weights are saved to
    ``result/ginka.pth`` and ``result/minamo.pth``.

    Raises:
        ValueError: if the dataloader yields no batches (with
            ``drop_last=True`` this happens when the dataset holds fewer
            than ``BATCH_SIZE`` samples).
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    latent_dim = 1024  # size of the noise vector fed to the generator
    c_steps = 1
    g_steps = 3

    args = parse_arguments()

    ginka = GinkaModel()
    minamo = MinamoScoreModule()
    ginka.to(device)
    minamo.to(device)

    dataset = GinkaWGANDataset(args.train, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    num_batches = len(dataloader)
    if num_batches == 0:
        # The per-epoch averages divide by the batch count; fail with a
        # clear message instead of a ZeroDivisionError after the first epoch.
        raise ValueError(
            f"Dataset '{args.train}' yields no batches of size {BATCH_SIZE}."
        )

    optimizer_ginka = optim.RMSprop(ginka.parameters(), lr=2e-4)
    optimizer_minamo = optim.RMSprop(minamo.parameters(), lr=1e-5)

    criterion = WGANGinkaLoss()

    # Tile images used when rendering generated maps to pictures.
    tile_dict = _load_tiles()

    if args.resume:
        data = torch.load(args.state_ginka, map_location=device)
        ginka.load_state_dict(data["model_state"], strict=False)
        data = torch.load(args.state_minamo, map_location=device)
        minamo.load_state_dict(data["model_state"], strict=False)
        print("Train from loaded state.")

    for epoch in tqdm(range(args.epochs), desc="GAN Training"):
        loss_total_minamo = torch.zeros(1, device=device)
        loss_total_ginka = torch.zeros(1, device=device)
        dis_total = torch.zeros(1, device=device)

        for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress"):
            batch_size = real_data.size(0)
            real_data = real_data.to(device)
            real_graph = batch_convert_soft_map_to_graph(real_data)

            # ---------- Train the critic
            for _ in range(c_steps):
                optimizer_minamo.zero_grad()
                # Generate fake samples; detached so the critic update does
                # not backpropagate into the generator.
                z = torch.randn(batch_size, latent_dim, device=device)
                fake_data = ginka(z).detach()

                dis, loss_d = criterion.discriminator_loss(
                    minamo, real_data, real_graph, fake_data
                )
                loss_d.backward()
                optimizer_minamo.step()

                # Accumulate statistics without recording autograd history,
                # otherwise every batch's graph stays reachable all epoch.
                with torch.no_grad():
                    loss_total_minamo += loss_d
                    dis_total += dis

            # ---------- Train the generator
            for _ in range(g_steps):
                # Reset per step: otherwise gradients from the earlier
                # generator steps of this batch leak into the later ones.
                optimizer_ginka.zero_grad()
                z1 = torch.randn(batch_size, latent_dim, device=device)
                z2 = torch.randn(batch_size, latent_dim, device=device)
                fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)

                loss_g = criterion.generator_loss(minamo, fake_softmax1, fake_softmax2)
                loss_g.backward()
                optimizer_ginka.step()

                with torch.no_grad():
                    loss_total_ginka += loss_g

        avg_loss_ginka = loss_total_ginka.item() / num_batches / g_steps
        avg_loss_minamo = loss_total_minamo.item() / num_batches / c_steps
        avg_dis = dis_total.item() / num_batches / c_steps
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}"
        )

        # The more negative the averaged value, the more generator steps
        # are run per batch in the next epoch.
        g_steps = _select_generator_steps(avg_dis)

        # Every five epochs: write sample images and save checkpoints.
        if (epoch + 1) % 5 == 0:
            _save_samples(ginka, tile_dict, latent_dim)
            _save_state(ginka, f"result/wgan/ginka-{epoch + 1}.pth")
            _save_state(minamo, f"result/wgan/minamo-{epoch + 1}.pth")

    print("Train ended.")
    _save_state(ginka, "result/ginka.pth")
    _save_state(minamo, "result/minamo.pth")
if __name__ == "__main__":
    # Script entry point: limit PyTorch's CPU thread count, then start training.
    cpu_threads = 4
    torch.set_num_threads(cpu_threads)
    train()