# Mirror of https://github.com/unanmed/ginka-generator.git
# (synced 2026-05-19 08:21:11 +08:00; 84 lines, 2.4 KiB, Python)
import os
|
|
from datetime import datetime
|
|
import torch
|
|
import torch.optim as optim
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import DataLoader
|
|
from transformers import BertTokenizer
|
|
from tqdm import tqdm
|
|
from .model.model import GinkaModel
|
|
from .model.loss import GinkaLoss
|
|
from .dataset import GinkaDataset
|
|
from minamo.model.model import MinamoModel
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The final checkpoint is written under ./result, so make sure it exists.
os.makedirs("result", exist_ok=True)

# Number of passes over the dataset performed by train().
epochs = 70
|
|
|
|
def update_tau(
    epoch: int,
    start_tau: float = 1.0,
    min_tau: float = 0.1,
    decay_rate: float = 0.95,
) -> float:
    """Return the temperature ``tau`` to use for the given epoch.

    The value decays exponentially: it starts at ``start_tau`` for epoch 0,
    is multiplied by ``decay_rate`` once per epoch, and is clamped so it
    never drops below ``min_tau``.

    Args:
        epoch: Zero-based epoch index.
        start_tau: Temperature at epoch 0.
        min_tau: Lower bound applied to the decayed temperature.
        decay_rate: Multiplicative decay factor applied per epoch.

    Returns:
        ``max(min_tau, start_tau * decay_rate ** epoch)``.
    """
    decayed_tau = start_tau * (decay_rate ** epoch)
    return max(min_tau, decayed_tau)
|
|
|
|
def train():
    """Train a ``GinkaModel`` and save the result to ``result/ginka.pth``.

    Builds the model, a ``MinamoModel(32)`` in eval mode (handed to both the
    dataset and the loss), an AdamW optimizer and a cosine warm-restart
    scheduler, then runs ``epochs`` passes over ``dataset.json``. At the
    start of every epoch the model's ``softmax.tau`` is set from
    ``update_tau(epoch)``. After the last epoch the model and optimizer
    state dicts are written to ``result/ginka.pth``.
    """
    # `device` is the module-level torch.device, so its type is exactly the
    # "cuda"/"cpu" string that was previously recomputed here.
    print(f"Using {device.type} to train model.")

    model = GinkaModel()
    model.to(device)

    # NOTE(review): no checkpoint is loaded into MinamoModel in this
    # function — confirm whether pretrained weights are expected here.
    minamo = MinamoModel(32)
    minamo.to(device)
    minamo.eval()

    # Prepare the dataset.
    dataset = GinkaDataset("dataset.json", minamo)
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
    )

    # Set up the optimizer, learning-rate scheduler and loss.
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-6
    )
    criterion = GinkaLoss(minamo)

    # Training loop.
    for epoch in tqdm(range(epochs)):
        model.train()
        # Sum (not mean) of the per-batch losses for this epoch.
        total_loss = 0
        model.softmax.tau = update_tau(epoch)

        for batch in dataloader:
            # Move the batch to the training device.
            target = batch["target"].to(device)
            feat_vec = batch["feat_vec"].to(device)

            # Forward pass.
            optimizer.zero_grad()
            output, output_softmax = model(feat_vec)

            # Compute the loss.
            loss = criterion(output, output_softmax, target)

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Log before stepping the scheduler so the reported learning rate is
        # the one that was actually used during this epoch.
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        current_lr = optimizer.param_groups[0]['lr']
        tqdm.write(f"[INFO {timestamp}] Epoch: {epoch} | loss: {total_loss:.6f} | lr: {current_lr:.6f}")

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()

    print("Train ended.")

    torch.save(
        {
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        "result/ginka.pth",
    )
|
|
|
|
if __name__ == "__main__":
    # Limit PyTorch's CPU thread count to 8 before training starts.
    torch.set_num_threads(8)
    train()
|