diff --git a/minamo/model/loss.py b/minamo/model/loss.py
index fe99bca..671b137 100644
--- a/minamo/model/loss.py
+++ b/minamo/model/loss.py
@@ -1,7 +1,7 @@
 import torch.nn as nn
 
 class MinamoLoss(nn.Module):
-    def __init__(self, vision_weight=0.4, topo_weight=0.6):
+    def __init__(self, vision_weight=0, topo_weight=1):
         super().__init__()
         self.vision_weight = vision_weight
         self.topo_weight = topo_weight
diff --git a/minamo/model/topo.py b/minamo/model/topo.py
index 6ddb236..6e72f1f 100644
--- a/minamo/model/topo.py
+++ b/minamo/model/topo.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch_geometric.nn import GCNConv, global_mean_pool
+from torch_geometric.nn import GCNConv, global_mean_pool, TopKPooling, GATConv
 from torch_geometric.data import Data
 
 class MinamoTopoModel(nn.Module):
@@ -12,18 +12,42 @@ class MinamoTopoModel(nn.Module):
 
         # Embedding layer
         self.embedding = torch.nn.Embedding(tile_types, emb_dim)
         # Graph convolution layers
-        self.conv1 = GCNConv(emb_dim, hidden_dim)
-        self.conv2 = GCNConv(hidden_dim, out_dim)
-        self.fc = torch.nn.Linear(out_dim, mlp_dim)  # dimensionality-reducing fully connected layer
+        self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=4, dropout=0.2)
+        self.conv2 = GATConv(hidden_dim*8, hidden_dim*4, heads=2)
+        self.conv3 = GATConv(hidden_dim*8, out_dim, concat=False)
+
+        # Normalization layers
+        self.norm1 = nn.LayerNorm(hidden_dim*8)
+        self.norm2 = nn.LayerNorm(hidden_dim*8)
+        self.norm3 = nn.LayerNorm(out_dim)
+
+        # Pooling layer
+        self.pool = TopKPooling(out_dim, ratio=0.8, nonlinearity=torch.sigmoid)  # keep the top 80% of nodes
+
+        # Enhanced MLP
+        self.fc = nn.Sequential(
+            nn.Linear(out_dim, mlp_dim*2),
+            nn.ReLU(),
+            nn.Linear(mlp_dim*2, mlp_dim)
+        )
 
     def forward(self, graph: Data):
         x = self.embedding(graph.x)
+        # identity = x
+
         x = self.conv1(x, graph.edge_index)
-        x = F.relu(x)
+        x = F.elu(self.norm1(x))
+
         x = self.conv2(x, graph.edge_index)
-        x = global_mean_pool(x, graph.batch)
-
-        # Dimensionality reduction via the fully connected layer
-        x = self.fc(x)
-        return x  # (batch_size, mlp_dim)
+        x = F.elu(self.norm2(x))
+
+        x = self.conv3(x, graph.edge_index)
+        x = F.elu(self.norm3(x))
+
+        # Hierarchical pooling
+        x, _, _, batch, _, _ = self.pool(x, graph.edge_index, batch=graph.batch)
+        x = global_mean_pool(x, batch)
+
+        # Enhanced MLP
+        return self.fc(x)
\ No newline at end of file
diff --git a/minamo/train.py b/minamo/train.py
index 729ae93..5006f0e 100644
--- a/minamo/train.py
+++ b/minamo/train.py
@@ -52,7 +52,7 @@ def train():
     )
 
     # Set up the optimizer and scheduler
-    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-3)
+    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-3)
     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
     criterion = MinamoLoss()
 
@@ -85,12 +85,12 @@
             optimizer.step()
             total_loss += loss.item()
 
-            total_norm = 0
-            for p in model.parameters():
-                if p.grad is not None:
-                    param_norm = p.grad.detach().data.norm(2)
-                    total_norm += param_norm.item() ** 2
-            total_norm = total_norm ** 0.5
+            # total_norm = 0
+            # for p in model.parameters():
+            #     if p.grad is not None:
+            #         param_norm = p.grad.detach().data.norm(2)
+            #         total_norm += param_norm.item() ** 2
+            # total_norm = total_norm ** 0.5
             # tqdm.write(f"Gradient Norm: {total_norm:.4f}")  # should normally stay between 1 and 100
 
         ave_loss = total_loss / len(dataloader)
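
Note (minamo/model/topo.py): the LayerNorm sizes depend on GATConv's head concatenation: with concat left at its default of True, the output width is heads * out_channels, so both norm1 and norm2 see hidden_dim*8. A minimal sketch verifying the bookkeeping; emb_dim, hidden_dim, and out_dim are placeholder sizes, since the diff does not show the model's actual defaults:

    import torch
    from torch_geometric.nn import GATConv

    emb_dim, hidden_dim, out_dim = 32, 64, 128   # placeholder sizes
    x = torch.randn(10, emb_dim)                 # 10 node feature vectors
    edge_index = torch.randint(0, 10, (2, 40))   # 40 random edges

    conv1 = GATConv(emb_dim, hidden_dim * 2, heads=4, dropout=0.2)
    conv2 = GATConv(hidden_dim * 8, hidden_dim * 4, heads=2)
    conv3 = GATConv(hidden_dim * 8, out_dim, concat=False)

    h = conv1(x, edge_index)   # 4 heads concatenated: 4 * (hidden_dim*2) = hidden_dim*8
    h = conv2(h, edge_index)   # 2 heads concatenated: 2 * (hidden_dim*4) = hidden_dim*8
    h = conv3(h, edge_index)   # concat=False averages the heads -> out_dim
    print(h.shape)             # torch.Size([10, 128])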
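
Note (minamo/model/topo.py, forward): TopKPooling returns a 6-tuple (x, edge_index, edge_attr, batch, perm, score); the model keeps only the filtered node features and the remapped batch vector before the mean-pool readout. A standalone sketch of that step, with the same placeholder sizes as above:

    import torch
    from torch_geometric.nn import TopKPooling, global_mean_pool

    out_dim = 128
    pool = TopKPooling(out_dim, ratio=0.8, nonlinearity=torch.sigmoid)

    x = torch.randn(10, out_dim)
    edge_index = torch.randint(0, 10, (2, 40))
    batch = torch.zeros(10, dtype=torch.long)    # one graph in the batch

    x, edge_index, edge_attr, batch, perm, score = pool(x, edge_index, batch=batch)
    print(x.shape)                               # torch.Size([8, 128]): top 80% of 10 nodes
    out = global_mean_pool(x, batch)             # readout over the surviving nodes
    print(out.shape)                             # torch.Size([1, 128])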
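
Note (minamo/train.py): the hand-rolled gradient-norm loop that this diff comments out has a shorter equivalent: clip_grad_norm_ returns the total 2-norm of all gradients, and an infinite max_norm makes the clipping itself a no-op, leaving pure measurement. A self-contained sketch, with nn.Linear standing in for the real model:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)                      # stand-in for the real model
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()

    # Total 2-norm of all gradients; max_norm=inf disables the actual clipping.
    total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
    print(f"Gradient Norm: {total_norm:.4f}")    # should normally stay between 1 and 100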