From ce03ded9dcb5eeda74f341f55b2018e9ff3d1937 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 17 Mar 2025 00:00:52 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=BC=95=E5=85=A5=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- minamo/train.py | 20 ++++++++++++-------- shared/graph.py | 2 +- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/minamo/train.py b/minamo/train.py index 5006f0e..f33b30a 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -85,17 +85,21 @@ def train(): optimizer.step() total_loss += loss.item() - # total_norm = 0 - # for p in model.parameters(): - # if p.grad is not None: - # param_norm = p.grad.detach().data.norm(2) - # total_norm += param_norm.item() ** 2 - # total_norm = total_norm ** 0.5 - # tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间 - ave_loss = total_loss / len(dataloader) tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") + # total_norm = 0 + # for p in model.parameters(): + # if p.grad is not None: + # param_norm = p.grad.detach().data.norm(2) + # total_norm += param_norm.item() ** 2 + # total_norm = total_norm ** 0.5 + # tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间 + + # for name, param in model.named_parameters(): + # if param.grad is not None: + # print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}") + # 学习率调整 scheduler.step() diff --git a/shared/graph.py b/shared/graph.py index 187051c..acbed92 100644 --- a/shared/graph.py +++ b/shared/graph.py @@ -1,5 +1,5 @@ import torch -from torch_geometric import Data +from torch_geometric.data import Data def convert_map_to_graph(map): rows = len(map)