From 7d0f567cc1b83ed495cd2478a25df702ec380ae5 Mon Sep 17 00:00:00 2001
From: unanmed <1319491857@qq.com>
Date: Mon, 17 Mar 2025 11:16:07 +0800
Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=20Minamo=20=E7=BD=91?=
 =?UTF-8?q?=E7=BB=9C=E5=8F=82=E6=95=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 minamo/model/topo.py | 20 +++++++++++---------
 minamo/train.py      |  4 ++--
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/minamo/model/topo.py b/minamo/model/topo.py
index 6e72f1f..d1b4220 100644
--- a/minamo/model/topo.py
+++ b/minamo/model/topo.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch_geometric.nn import GCNConv, global_mean_pool, TopKPooling, GATConv
+from torch_geometric.nn import global_mean_pool, TopKPooling, GATConv
 from torch_geometric.data import Data
 
 class MinamoTopoModel(nn.Module):
@@ -12,17 +12,18 @@ class MinamoTopoModel(nn.Module):
         # 嵌入层
         self.embedding = torch.nn.Embedding(tile_types, emb_dim)
         # 图卷积层
-        self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=4, dropout=0.2)
-        self.conv2 = GATConv(hidden_dim*8, hidden_dim*4, heads=2)
-        self.conv3 = GATConv(hidden_dim*8, out_dim, concat=False)
+        self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
+        self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
+        self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False)
         # 正则化
-        self.norm1 = nn.LayerNorm(hidden_dim*8)
-        self.norm2 = nn.LayerNorm(hidden_dim*8)
+        self.norm1 = nn.LayerNorm(hidden_dim*16)
+        self.norm2 = nn.LayerNorm(hidden_dim*16)
         self.norm3 = nn.LayerNorm(out_dim)
         # 池化层
-        self.pool = TopKPooling(out_dim, ratio=0.8, nonlinearity=torch.sigmoid) # 保留80%关键节点
+        self.pool = TopKPooling(out_dim, ratio=0.8) # 保留80%关键节点
+        self.drop = nn.Dropout(0.3)
 
         # 增强MLP
         self.fc = nn.Sequential(
@@ -45,8 +46,9 @@ class MinamoTopoModel(nn.Module):
         x = F.elu(self.norm3(x))
 
         # 分层池化
-        x, _, _, batch, _, _ = self.pool(x, graph.edge_index, batch=graph.batch)
-        x = global_mean_pool(x, batch)
+        x = self.drop(x)
+        # x, _, _, batch, _, _ = self.pool(x, graph.edge_index, batch=graph.batch)
+        x = global_mean_pool(x, graph.batch)
 
         # 增强MLP
         return self.fc(x)
diff --git a/minamo/train.py b/minamo/train.py
index f33b30a..7db7294 100644
--- a/minamo/train.py
+++ b/minamo/train.py
@@ -52,8 +52,8 @@ def train():
     )
 
     # 设定优化器与调度器
-    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-3)
-    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
+    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
     criterion = MinamoLoss()
 
     # 开始训练