perf: 优化 Minamo 网络参数

This commit is contained in:
unanmed 2025-03-17 11:16:07 +08:00
parent 8d39a7a0c8
commit 7d0f567cc1
2 changed files with 13 additions and 11 deletions

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, TopKPooling, GATConv from torch_geometric.nn import global_mean_pool, TopKPooling, GATConv
from torch_geometric.data import Data from torch_geometric.data import Data
class MinamoTopoModel(nn.Module): class MinamoTopoModel(nn.Module):
@ -12,17 +12,18 @@ class MinamoTopoModel(nn.Module):
# 嵌入层 # 嵌入层
self.embedding = torch.nn.Embedding(tile_types, emb_dim) self.embedding = torch.nn.Embedding(tile_types, emb_dim)
# 图卷积层 # 图卷积层
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=4, dropout=0.2) self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
self.conv2 = GATConv(hidden_dim*8, hidden_dim*4, heads=2) self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
self.conv3 = GATConv(hidden_dim*8, out_dim, concat=False) self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False)
# 正则化 # 正则化
self.norm1 = nn.LayerNorm(hidden_dim*8) self.norm1 = nn.LayerNorm(hidden_dim*16)
self.norm2 = nn.LayerNorm(hidden_dim*8) self.norm2 = nn.LayerNorm(hidden_dim*16)
self.norm3 = nn.LayerNorm(out_dim) self.norm3 = nn.LayerNorm(out_dim)
# 池化层 # 池化层
self.pool = TopKPooling(out_dim, ratio=0.8, nonlinearity=torch.sigmoid) # 保留80%关键节点 self.pool = TopKPooling(out_dim, ratio=0.8) # 保留80%关键节点
self.drop = nn.Dropout(0.3)
# 增强MLP # 增强MLP
self.fc = nn.Sequential( self.fc = nn.Sequential(
@ -45,8 +46,9 @@ class MinamoTopoModel(nn.Module):
x = F.elu(self.norm3(x)) x = F.elu(self.norm3(x))
# 分层池化 # 分层池化
x, _, _, batch, _, _ = self.pool(x, graph.edge_index, batch=graph.batch) x = self.drop(x)
x = global_mean_pool(x, batch) # x, _, _, batch, _, _ = self.pool(x, graph.edge_index, batch=graph.batch)
x = global_mean_pool(x, graph.batch)
# 增强MLP # 增强MLP
return self.fc(x) return self.fc(x)

View File

@ -52,8 +52,8 @@ def train():
) )
# 设定优化器与调度器 # 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-3) optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
criterion = MinamoLoss() criterion = MinamoLoss()
# 开始训练 # 开始训练