ginka-generator/ginka/train.py

import os
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from tqdm import tqdm
from .model.model import GinkaModel
from .model.loss import GinkaLoss
from .dataset import GinkaDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
epochs = 100


def collate_fn(batch):
    # Dynamically pad noise/target maps to the largest size in the batch
    max_h = max(b["noise"].shape[0] for b in batch)
    max_w = max(b["noise"].shape[1] for b in batch)
    padded_batch = {}
    for key in ["noise", "target"]:
        padded = []
        for b in batch:
            tensor = b[key]
            pad_h = max_h - tensor.shape[0]
            pad_w = max_w - tensor.shape[1]
            # Pad right/bottom; targets use -100 so padded cells are ignored by the loss
            padded.append(F.pad(tensor, (0, pad_w, 0, pad_h), value=-100 if key == "target" else 0))
        padded_batch[key] = torch.stack(padded)

    # Fixed-size fields can be stacked directly
    for key in ["input_ids", "attention_mask", "map_size"]:
        padded_batch[key] = torch.stack([b[key] for b in batch])
    return padded_batch
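
# For illustration: given two samples whose maps are 13x13 and 15x11,
# collate_fn pads both to 15x13 before stacking, zero-padding the noise and
# padding the target with -100, the conventional ignore_index for
# cross-entropy-style losses (assuming GinkaLoss follows that convention).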


def train():
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
    model = GinkaModel()
    model.to(device)

    # Prepare the dataset
    tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-chinese')
    dataset = GinkaDataset("F:/github-ai/ginka-generator/dataset.json", tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0
    )

    # Set up the optimizer, scheduler, and loss criterion
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = GinkaLoss()
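    # CosineAnnealingLR follows a half-cosine schedule,
    #   lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2,
    # with eta_min = 0 by default, so the rate falls from 3e-4 toward 0 over the run.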

    # Start training
    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0

        # Temperature annealing
        model.gumbel.tau = max(0.1, 1.0 - 0.9 * epoch / epochs)
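        # The temperature decays linearly from 1.0 to 0.1 over training, so
        # sampling starts soft and sharpens toward one-hot outputs (assuming
        # model.gumbel is a Gumbel-Softmax sampling layer).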

        for batch in dataloader:
            # Move the batch to the device
            noise = batch["noise"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            map_size = batch["map_size"].to(device)
            target = batch["target"].to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(noise, input_ids, attention_mask, map_size)
            # Debug output: per-cell predicted classes for this batch
            print(torch.argmax(torch.softmax(outputs, dim=1), dim=1))
            # print(sampled[0, :, :, 1])

            # Build topology graphs (currently disabled)
            # with torch.no_grad():
            #     pred_graphs = build_topology_graph(outputs.argmax(1))
            #     ref_graphs = build_topology_graph(target)

            # Compute the loss
            loss = criterion(
                outputs,  # shaped [BS, C, H, W]
                target
            )

            # Backward pass
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {total_loss:.6f} | lr: {optimizer.param_groups[0]['lr']:.6f}")

        # Adjust the learning rate
        scheduler.step()
print("Train ended.")
torch.save({
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),
}, f"result/ginka.pth")


if __name__ == "__main__":
    torch.set_num_threads(8)
    train()