fix: 引入错误

This commit is contained in:
unanmed 2025-03-17 00:00:52 +08:00
parent 50bb509a84
commit ce03ded9dc
2 changed files with 13 additions and 9 deletions

View File

@ -85,17 +85,21 @@ def train():
optimizer.step()
total_loss += loss.item()
# total_norm = 0
# for p in model.parameters():
# if p.grad is not None:
# param_norm = p.grad.detach().data.norm(2)
# total_norm += param_norm.item() ** 2
# total_norm = total_norm ** 0.5
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
ave_loss = total_loss / len(dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
# total_norm = 0
# for p in model.parameters():
# if p.grad is not None:
# param_norm = p.grad.detach().data.norm(2)
# total_norm += param_norm.item() ** 2
# total_norm = total_norm ** 0.5
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
# for name, param in model.named_parameters():
# if param.grad is not None:
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
# 学习率调整
scheduler.step()

View File

@ -1,5 +1,5 @@
import torch
from torch_geometric import Data
from torch_geometric.data import Data
def convert_map_to_graph(map):
rows = len(map)