# Mirror of https://github.com/unanmed/ginka-generator.git
# Synced 2026-05-14 04:41:12 +08:00
# 144 lines, 5.1 KiB, Python
import os
|
|
from datetime import datetime
|
|
import torch
|
|
import torch.optim as optim
|
|
from torch_geometric.loader import DataLoader
|
|
from tqdm import tqdm
|
|
from .model.model import MinamoModel
|
|
from .model.loss import MinamoLoss
|
|
from .dataset import MinamoDataset
|
|
|
|
# Train on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Make sure the output directories exist before anything is written to them.
for _out_dir in ("result", "result/minamo_checkpoint"):
    os.makedirs(_out_dir, exist_ok=True)

# Number of passes over the training set; also used as the scheduler period.
epochs = 100
|
|
|
|
def collate_fn(batch):
    """Collate map-pair samples into batched tensors.

    Args:
        batch: Sequence of samples. Each sample is indexable and holds, in
            order: the first map tensor, the second map tensor, the visual
            similarity tensor and the topological similarity tensor. Any
            elements past index 3 are ignored.

    Returns:
        A tuple ``(map1, map2, vis_sim, topo_sim)``. The maps are stacked
        along a new leading batch dimension; the similarity tensors are
        concatenated along their first dimension.

    Raises:
        AssertionError: If the first maps in the batch do not all share the
            same shape.
    """
    map1_batch = [item[0] for item in batch]
    map2_batch = [item[1] for item in batch]
    vis_sim = torch.cat([item[2] for item in batch])
    topo_sim = torch.cat([item[3] for item in batch])

    # Maps inside one batch must share a size, otherwise they cannot be
    # stacked. This is raised explicitly instead of via an ``assert``
    # statement so the check still runs under ``python -O``; the exception
    # type and message are kept so existing callers see the same error.
    reference_shape = map1_batch[0].shape if map1_batch else None
    if any(m.shape != reference_shape for m in map1_batch):
        raise AssertionError("对比地图必须尺寸相同")

    return (
        torch.stack(map1_batch),  # (B, H, W) when each map is (H, W)
        torch.stack(map2_batch),  # (B, H, W) when each map is (H, W)
        vis_sim,
        topo_sim,
    )
|
|
|
|
def _batch_to_device(batch):
    """Unpack one loader batch and move each of its six elements to ``device``."""
    map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
    return (
        map1.to(device),
        map2.to(device),
        vision_simi.to(device),
        topo_simi.to(device),
        graph1.to(device),
        graph2.to(device),
    )


def _timestamp():
    """Return the current local time formatted for the log prefix."""
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def _evaluate(model, loader, criterion):
    """Return the mean per-batch loss of ``model`` over ``loader``, without gradients."""
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in loader:
            map1, map2, vision_simi, topo_simi, graph1, graph2 = _batch_to_device(batch)
            vision_pred, topo_pred = model(map1, map2, graph1, graph2)
            total_loss += criterion(vision_pred, topo_pred, vision_simi, topo_simi).item()
    return total_loss / len(loader)


def _save_checkpoint(model, optimizer, path):
    """Write the model and optimizer state dicts to ``path``."""
    torch.save({
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)


def train():
    """Train a ``MinamoModel`` and write checkpoints under ``result/``.

    The training set is read from ``minamo-dataset.json`` and the validation
    set from ``minamo-eval.json``. The model is optimised with AdamW under a
    cosine-annealing learning-rate schedule for ``epochs`` epochs, and the
    mean training loss is logged after every epoch. Every fifth epoch the
    validation loss is logged and a checkpoint is written to
    ``result/minamo_checkpoint/<epoch>.pth``. The final model and optimizer
    state are saved to ``result/minamo.pth``.
    """
    print(f"Using {device.type} to train model.")
    model = MinamoModel(32)
    model.to(device)

    # Prepare the datasets.
    dataset = MinamoDataset("minamo-dataset.json")
    val_dataset = MinamoDataset("minamo-eval.json")
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    # Validation does not benefit from shuffling; a fixed order keeps the
    # batches (and so the reported loss) reproducible between evaluations.
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

    # Set up the optimizer, the learning-rate scheduler and the loss.
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = MinamoLoss()

    # Training loop.
    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0

        for batch in dataloader:
            map1, map2, vision_simi, topo_simi, graph1, graph2 = _batch_to_device(batch)

            # Forward pass.
            optimizer.zero_grad()
            vision_pred, topo_pred = model(map1, map2, graph1, graph2)

            # Loss.
            loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi)

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        ave_loss = total_loss / len(dataloader)
        # Logged before scheduler.step(), so this is the rate the epoch used.
        current_lr = optimizer.param_groups[0]['lr']
        tqdm.write(f"[INFO {_timestamp()}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {current_lr:.6f}")

        # NOTE: when debugging unstable training, log the global L2 norm of
        # the gradients (or per-parameter gradient mean/max) at this point;
        # the original author expected the norm to stay roughly within 1~100.

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()

        # Every 5 epochs: evaluate on the validation set and save a checkpoint.
        if (epoch + 1) % 5 == 0:
            avg_val_loss = _evaluate(model, val_loader, criterion)
            tqdm.write(f"[INFO {_timestamp()}] Validation::loss: {avg_val_loss:.6f}")
            _save_checkpoint(model, optimizer, f"result/minamo_checkpoint/{epoch + 1}.pth")

    print("Train ended.")

    _save_checkpoint(model, optimizer, "result/minamo.pth")
|
|
|
|
if __name__ == "__main__":
    # Limit PyTorch's intra-op CPU parallelism to two threads, then train.
    cpu_threads = 2
    torch.set_num_threads(cpu_threads)
    train()
|