mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 引入错误
This commit is contained in:
parent
50bb509a84
commit
ce03ded9dc
@ -85,17 +85,21 @@ def train():
|
||||
optimizer.step()
|
||||
total_loss += loss.item()
|
||||
|
||||
# total_norm = 0
|
||||
# for p in model.parameters():
|
||||
# if p.grad is not None:
|
||||
# param_norm = p.grad.detach().data.norm(2)
|
||||
# total_norm += param_norm.item() ** 2
|
||||
# total_norm = total_norm ** 0.5
|
||||
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
|
||||
|
||||
ave_loss = total_loss / len(dataloader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
|
||||
# total_norm = 0
|
||||
# for p in model.parameters():
|
||||
# if p.grad is not None:
|
||||
# param_norm = p.grad.detach().data.norm(2)
|
||||
# total_norm += param_norm.item() ** 2
|
||||
# total_norm = total_norm ** 0.5
|
||||
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
|
||||
|
||||
# for name, param in model.named_parameters():
|
||||
# if param.grad is not None:
|
||||
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
|
||||
|
||||
# 学习率调整
|
||||
scheduler.step()
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from torch_geometric import Data
|
||||
from torch_geometric.data import Data
|
||||
|
||||
def convert_map_to_graph(map):
|
||||
rows = len(map)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user