ginka-generator/ginka/train.py

84 lines
2.4 KiB
Python

import os
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from tqdm import tqdm
from .model.model import GinkaModel
from .model.loss import GinkaLoss
from .dataset import GinkaDataset
from minamo.model.model import MinamoModel
# Number of full passes over the dataset performed by train().
epochs = 70

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The checkpoint is written under "result/", so make sure the directory exists.
os.makedirs("result", exist_ok=True)
def update_tau(epoch, start_tau=1.0, min_tau=0.1, decay_rate=0.95):
    """Return the exponentially decayed tau value for the given epoch.

    The value starts at ``start_tau``, is multiplied by ``decay_rate`` once
    per epoch, and is clamped so it never drops below ``min_tau``.

    Args:
        epoch: Epoch index used as the decay exponent (0 yields ``start_tau``).
        start_tau: Value before any decay is applied.
        min_tau: Lower bound of the returned value.
        decay_rate: Multiplicative factor applied once per epoch.

    Returns:
        ``max(min_tau, start_tau * decay_rate ** epoch)``.
    """
    return max(min_tau, start_tau * (decay_rate ** epoch))
def train():
    """Train a GinkaModel on ``dataset.json`` and save a checkpoint.

    Builds the generator model and a MinamoModel, which is handed to both the
    dataset and the loss. Runs ``epochs`` epochs of AdamW optimisation with a
    cosine-annealing warm-restart learning-rate schedule stepped once per
    epoch. After the last epoch the model and optimizer state dicts are
    written to ``result/ginka.pth``.

    The per-epoch log line reports the *sum* of the batch losses, not the
    mean.
    """
    print(f"Using {device.type} to train model.")

    model = GinkaModel()
    model.to(device)

    # NOTE(review): no pretrained weights are loaded into minamo in this
    # function -- confirm whether that is intended or handled elsewhere.
    minamo = MinamoModel(32)
    minamo.to(device)
    minamo.eval()
    # minamo is never passed to the optimizer, so freeze its parameters.
    # Otherwise every backward pass computes and accumulates gradients into
    # them that are never used or zeroed. Gradients still flow *through*
    # minamo to the generator's outputs.
    minamo.requires_grad_(False)

    # Prepare the dataset.
    dataset = GinkaDataset("dataset.json", minamo)
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
    )

    # Set up the optimizer, LR scheduler and loss.
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-6
    )
    criterion = GinkaLoss(minamo)

    # Start training.
    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0
        # Decay tau on the model's softmax module once per epoch
        # (presumably a softmax temperature -- verify against GinkaModel).
        model.softmax.tau = update_tau(epoch)

        for batch in dataloader:
            # Move the batch to the training device.
            target = batch["target"].to(device)
            feat_vec = batch["feat_vec"].to(device)

            # Forward pass.
            optimizer.zero_grad()
            output, output_softmax = model(feat_vec)

            # Compute the loss.
            loss = criterion(output, output_softmax, target)

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {total_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")

        # Advance the learning-rate schedule (once per epoch).
        scheduler.step()

    print("Train ended.")
    torch.save({
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, "result/ginka.pth")
if __name__ == "__main__":
    # Limit PyTorch's intra-op CPU thread pool to 8 threads before training.
    torch.set_num_threads(8)
    train()