mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 18:31:13 +08:00
feat: 优化图卷积深度
This commit is contained in:
parent
b43a8693ef
commit
41a9e21247
@ -1,7 +1,7 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class MinamoLoss(nn.Module):
|
class MinamoLoss(nn.Module):
|
||||||
def __init__(self, vision_weight=0.4, topo_weight=0.6):
|
def __init__(self, vision_weight=0, topo_weight=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vision_weight = vision_weight
|
self.vision_weight = vision_weight
|
||||||
self.topo_weight = topo_weight
|
self.topo_weight = topo_weight
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch_geometric.nn import GCNConv, global_mean_pool
|
from torch_geometric.nn import GCNConv, global_mean_pool, TopKPooling, GATConv
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
|
|
||||||
class MinamoTopoModel(nn.Module):
|
class MinamoTopoModel(nn.Module):
|
||||||
@ -12,18 +12,42 @@ class MinamoTopoModel(nn.Module):
|
|||||||
# 嵌入层
|
# 嵌入层
|
||||||
self.embedding = torch.nn.Embedding(tile_types, emb_dim)
|
self.embedding = torch.nn.Embedding(tile_types, emb_dim)
|
||||||
# 图卷积层
|
# 图卷积层
|
||||||
self.conv1 = GCNConv(emb_dim, hidden_dim)
|
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=4, dropout=0.2)
|
||||||
self.conv2 = GCNConv(hidden_dim, out_dim)
|
self.conv2 = GATConv(hidden_dim*8, hidden_dim*4, heads=2)
|
||||||
self.fc = torch.nn.Linear(out_dim, mlp_dim) # 降维全连接层
|
self.conv3 = GATConv(hidden_dim*8, out_dim, concat=False)
|
||||||
|
|
||||||
|
# 正则化
|
||||||
|
self.norm1 = nn.LayerNorm(hidden_dim*8)
|
||||||
|
self.norm2 = nn.LayerNorm(hidden_dim*8)
|
||||||
|
self.norm3 = nn.LayerNorm(out_dim)
|
||||||
|
|
||||||
|
# 池化层
|
||||||
|
self.pool = TopKPooling(out_dim, ratio=0.8, nonlinearity=torch.sigmoid) # 保留80%关键节点
|
||||||
|
|
||||||
|
# 增强MLP
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(out_dim, mlp_dim*2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(mlp_dim*2, mlp_dim)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, graph: Data):
|
def forward(self, graph: Data):
|
||||||
x = self.embedding(graph.x)
|
x = self.embedding(graph.x)
|
||||||
|
# identity = x
|
||||||
|
|
||||||
x = self.conv1(x, graph.edge_index)
|
x = self.conv1(x, graph.edge_index)
|
||||||
x = F.relu(x)
|
x = F.elu(self.norm1(x))
|
||||||
|
|
||||||
x = self.conv2(x, graph.edge_index)
|
x = self.conv2(x, graph.edge_index)
|
||||||
x = global_mean_pool(x, graph.batch)
|
x = F.elu(self.norm2(x))
|
||||||
|
|
||||||
# 全连接层降维
|
x = self.conv3(x, graph.edge_index)
|
||||||
x = self.fc(x)
|
x = F.elu(self.norm3(x))
|
||||||
return x # (batch_size, mlp_dim)
|
|
||||||
|
# 分层池化
|
||||||
|
x, _, _, batch, _, _ = self.pool(x, graph.edge_index, batch=graph.batch)
|
||||||
|
x = global_mean_pool(x, batch)
|
||||||
|
|
||||||
|
# 增强MLP
|
||||||
|
return self.fc(x)
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ def train():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 设定优化器与调度器
|
# 设定优化器与调度器
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-3)
|
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-3)
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
||||||
criterion = MinamoLoss()
|
criterion = MinamoLoss()
|
||||||
|
|
||||||
@ -85,12 +85,12 @@ def train():
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|
||||||
total_norm = 0
|
# total_norm = 0
|
||||||
for p in model.parameters():
|
# for p in model.parameters():
|
||||||
if p.grad is not None:
|
# if p.grad is not None:
|
||||||
param_norm = p.grad.detach().data.norm(2)
|
# param_norm = p.grad.detach().data.norm(2)
|
||||||
total_norm += param_norm.item() ** 2
|
# total_norm += param_norm.item() ** 2
|
||||||
total_norm = total_norm ** 0.5
|
# total_norm = total_norm ** 0.5
|
||||||
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
|
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
|
||||||
|
|
||||||
ave_loss = total_loss / len(dataloader)
|
ave_loss = total_loss / len(dataloader)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user