Mirror of https://github.com/unanmed/ginka-generator.git (synced 2026-05-16 22:41:14 +08:00)
import os
from datetime import datetime

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from tqdm import tqdm

from .model.model import GinkaModel
from .model.loss import GinkaLoss
from .dataset import GinkaDataset
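
# Note: because of the relative imports above, run this file as a module, e.g.
# `python -m <package>.train` (the package name is not shown in this file);
# running it as a plain script raises
# "attempted relative import with no known parent package".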

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)

epochs = 100
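
# The maps in a batch may differ in size, so the default stacking collate would
# fail; collate_fn right/bottom-pads every map to the batch maximum first.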

def collate_fn(batch):
    # Dynamically pad the noise maps to the largest size in the batch
    max_h = max([b["noise"].shape[0] for b in batch])
    max_w = max([b["noise"].shape[1] for b in batch])

    padded_batch = {}
    for key in ["noise", "target"]:
        padded = []
        for b in batch:
            tensor = b[key]
            pad_h = max_h - tensor.shape[0]
            pad_w = max_w - tensor.shape[1]
            padded.append(F.pad(tensor, (0, pad_w, 0, pad_h), value=-100 if key == "target" else 0))
        padded_batch[key] = torch.stack(padded)

    # The remaining fields are fixed-size, so stack them directly
    for key in ["input_ids", "attention_mask", "map_size"]:
        padded_batch[key] = torch.stack([b[key] for b in batch])

    return padded_batch
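
# Note on the padding above: F.pad's pad tuple applies to trailing dims first, so
# (0, pad_w, 0, pad_h) pads only the right and bottom edges. Targets are filled
# with -100, PyTorch's default ignore_index for cross-entropy, presumably so that
# GinkaLoss skips padded cells (an assumption about GinkaLoss's internals).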

def train():
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
    model = GinkaModel()
    model.to(device)

    # Prepare the dataset
    tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-chinese')
    dataset = GinkaDataset("F:/github-ai/ginka-generator/dataset.json", tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0
    )
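    # Each yielded batch is a dict of stacked tensors: "noise", "target",
    # "input_ids", "attention_mask" and "map_size" (see collate_fn above).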

    # Set up the optimizer, scheduler and loss
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = GinkaLoss()
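    # With T_max=epochs and one scheduler.step() per epoch, the learning rate
    # follows a half cosine from 3e-4 down to CosineAnnealingLR's default
    # eta_min of 0 over the whole run.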

    # Start training
    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0

        # Temperature annealing
        model.gumbel.tau = max(0.1, 1.0 - 0.9 * epoch / epochs)
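        # tau decays linearly from 1.0 toward a floor of 0.1, so the (presumed)
        # Gumbel-Softmax sampling in the model starts smooth for stable gradients
        # and sharpens toward near-discrete tile choices late in training.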

        for batch in dataloader:
            # Move the data to the device
            noise = batch["noise"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            map_size = batch["map_size"].to(device)
            target = batch["target"].to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(noise, input_ids, attention_mask, map_size)

            # Debug: per-cell predicted classes (softmax does not change the argmax)
            print(torch.argmax(torch.softmax(outputs, dim=1), dim=1))
            # print(sampled[0, :, :, 1])

            # Build topology graphs (currently disabled)
            # with torch.no_grad():
            #     pred_graphs = build_topology_graph(outputs.argmax(1))
            #     ref_graphs = build_topology_graph(target)

            # Compute the loss
            loss = criterion(
                outputs,  # shaped [BS, C, H, W]
                target
            )

            # Backward pass
            loss.backward()
            optimizer.step()
            total_loss += loss.item()  # sum of batch losses for this epoch

        tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {total_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")

        # Adjust the learning rate once per epoch
        scheduler.step()
print("Train ended.")
|
|
|
|
torch.save({
|
|
"model_state": model.state_dict(),
|
|
"optimizer_state": optimizer.state_dict(),
|
|
}, f"result/ginka.pth")
|
|
|
|
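
# A minimal sketch for resuming from the checkpoint saved above (assumes the same
# GinkaModel/AdamW setup as in train()):
#   ckpt = torch.load("result/ginka.pth", map_location=device)
#   model.load_state_dict(ckpt["model_state"])
#   optimizer.load_state_dict(ckpt["optimizer_state"])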

if __name__ == "__main__":
    torch.set_num_threads(8)
    train()