From 8872de3f13129c14f7b5724588a4305e7a66d8c8 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 25 Mar 2025 21:12:30 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=B9=E8=BF=9B=E5=88=A4=E5=88=AB?= =?UTF-8?q?=E5=99=A8=E4=B8=8E=E7=94=9F=E6=88=90=E5=99=A8=E7=BD=91=E7=BB=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/src/topology/similarity.ts | 3 +- ginka/model/loss.py | 4 +-- ginka/model/model.py | 19 ++++++----- ginka/model/unet.py | 59 +++++++++++++++++++-------------- ginka/train.py | 1 - minamo/model/loss.py | 8 ++--- minamo/model/topo.py | 42 +++++++++-------------- minamo/model/vision.py | 55 ++++-------------------------- requirements.txt | 3 +- shared/attention.py | 17 ++++++++++ shared/graph.py | 40 +++++++++++----------- 11 files changed, 116 insertions(+), 135 deletions(-) diff --git a/data/src/topology/similarity.ts b/data/src/topology/similarity.ts index 83f943d..aa2799d 100644 --- a/data/src/topology/similarity.ts +++ b/data/src/topology/similarity.ts @@ -59,9 +59,10 @@ function weisfeilerLehmanIteration( const neighborLabels = node.neighbors .map(n => n.currentLabel) .sort(); + const compositeLabel = `${node.currentLabel}|${neighborLabels.join( ',' - )}`; + )}`.slice(0, 4096); newLabels.push(compositeLabel); }); diff --git a/ginka/model/loss.py b/ginka/model/loss.py index fa96174..55f3522 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -294,7 +294,7 @@ def entrance_spatial_constraint( return total_loss class GinkaLoss(nn.Module): - def __init__(self, minamo: MinamoModel, weight=[0.5, 0.15, 0.15, 0.1, 0.1]): + def __init__(self, minamo: MinamoModel, weight=[0.5, 0.1, 0.1, 0.2, 0.1]): """Ginka Model 损失函数部分 Args: @@ -335,7 +335,7 @@ class GinkaLoss(nn.Module): losses = [ minamo_loss * self.weight[0], - border_loss * self.weight[1] * 0.1, + border_loss * self.weight[1], entrance_loss * self.weight[2], count_loss * self.weight[3], illegal_loss * self.weight[4] diff --git a/ginka/model/model.py b/ginka/model/model.py index 6415dbd..334cd71 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -16,15 +16,18 @@ class GinkaModel(nn.Module): nn.BatchNorm1d(fc_dim), nn.ReLU() ) - self.deconv_layers = nn.Sequential( - nn.ConvTranspose2d(base_ch*8, base_ch*4, kernel_size=4, stride=2, padding=1), # Upsample 2x - nn.BatchNorm2d(base_ch*4), + self.upsample = nn.Sequential( + nn.Conv2d(base_ch*8, base_ch*16, kernel_size=3, stride=1, padding=1), + nn.PixelShuffle(2), + nn.InstanceNorm2d(base_ch*4), nn.ReLU(), - nn.ConvTranspose2d(base_ch*4, base_ch*2, kernel_size=4, stride=2, padding=1), # Upsample 2x - nn.BatchNorm2d(base_ch*2), + nn.Conv2d(base_ch*4, base_ch*8, kernel_size=3, stride=1, padding=1), + nn.PixelShuffle(2), + nn.InstanceNorm2d(base_ch*2), nn.ReLU(), - nn.ConvTranspose2d(base_ch*2, base_ch, kernel_size=4, stride=2, padding=1), # Upsample 2x - nn.BatchNorm2d(base_ch), + nn.Conv2d(base_ch*2, base_ch*4, kernel_size=3, stride=1, padding=1), + nn.PixelShuffle(2), + nn.InstanceNorm2d(base_ch), nn.ReLU(), ) self.unet = GinkaUNet(base_ch, num_classes) @@ -40,7 +43,7 @@ class GinkaModel(nn.Module): """ x = self.fc(feat) x = x.view(-1, self.base_ch*8, 4, 4) - x = self.deconv_layers(x) + x = self.upsample(x) x = self.unet(x) x = F.interpolate(x, (13, 13), mode='bilinear') return x, F.softmax(x, dim=1) diff --git a/ginka/model/unet.py b/ginka/model/unet.py index aaf8903..5a40ee9 100644 --- a/ginka/model/unet.py +++ b/ginka/model/unet.py @@ -1,39 +1,50 @@ import torch import torch.nn as nn -from shared.attention import CBAM +import torch.nn.functional as F +from shared.attention import CBAM, SEBlock class GinkaEncoder(nn.Module): """编码器(下采样)部分""" - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, attention=False, block='CBAM'): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), + nn.InstanceNorm2d(out_channels), nn.GELU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - # CBAM(out_channels), - nn.GELU() + nn.InstanceNorm2d(out_channels), ) + # 注意力 + if attention: + if block == 'CBAM': + self.conv.append(CBAM(out_channels)) + elif block == 'SEBlock': + self.conv.append(SEBlock(out_channels)) + self.conv.append(nn.GELU()) self.pool = nn.MaxPool2d(2) def forward(self, x): - x_res = self.conv(x) - x_down = self.pool(x_res) - return x_down, x_res + x_res = self.conv(x) + x_down = self.pool(x_res) + return x_down, x_res class GinkaDecoder(nn.Module): """解码器(上采样)部分""" - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, attention=False, block='CBAM'): super().__init__() self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) self.conv = nn.Sequential( nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - # CBAM(out_channels), - nn.GELU() + nn.InstanceNorm2d(out_channels), ) + # 注意力 + if attention: + if block == 'CBAM': + self.conv.append(CBAM(out_channels)) + elif block == 'SEBlock': + self.conv.append(SEBlock(out_channels)) + self.conv.append(nn.GELU()) def forward(self, x, skip): x = self.upsample(x) @@ -46,10 +57,11 @@ class GinkaBottleneck(nn.Module): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), + nn.InstanceNorm2d(out_channels), nn.GELU(), + SEBlock(out_channels), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), + nn.InstanceNorm2d(out_channels), nn.GELU(), ) @@ -61,21 +73,20 @@ class GinkaUNet(nn.Module): """Ginka Model UNet 部分 """ super().__init__() - self.down1 = GinkaEncoder(in_ch, in_ch*2) - self.down2 = GinkaEncoder(in_ch*2, in_ch*4) - self.down3 = GinkaEncoder(in_ch*4, in_ch*8) - self.down4 = GinkaEncoder(in_ch*8, in_ch*16) + self.down1 = GinkaEncoder(in_ch, in_ch*2, attention=True) + self.down2 = GinkaEncoder(in_ch*2, in_ch*4, attention=True) + self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock') + self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock') self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16) - self.up1 = GinkaDecoder(in_ch*16, in_ch*8) - self.up2 = GinkaDecoder(in_ch*8, in_ch*4) - self.up3 = GinkaDecoder(in_ch*4, in_ch*2) - self.up4 = GinkaDecoder(in_ch*2, in_ch) + self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock') + self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock') + self.up3 = GinkaDecoder(in_ch*4, in_ch*2, attention=True) + self.up4 = GinkaDecoder(in_ch*2, in_ch, attention=True) self.final = nn.Sequential( nn.Conv2d(in_ch, out_ch, 1), - # nn.Softmax(dim=1) # 适用于分类任务 ) def forward(self, x): diff --git a/ginka/train.py b/ginka/train.py index 97e824e..b4f7b9b 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -95,7 +95,6 @@ def train(): # for name, param in model.named_parameters(): # if param.grad is not None: # print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}") - avg_loss = total_loss / len(dataloader) tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") diff --git a/minamo/model/loss.py b/minamo/model/loss.py index c60a948..58795fc 100644 --- a/minamo/model/loss.py +++ b/minamo/model/loss.py @@ -1,16 +1,16 @@ import torch.nn as nn class MinamoLoss(nn.Module): - def __init__(self, vision_weight=0.4, topo_weight=0.6): + def __init__(self, vision_weight=1, topo_weight=0): super().__init__() self.vision_weight = vision_weight self.topo_weight = topo_weight - self.mse = nn.MSELoss() + self.loss = nn.L1Loss() def forward(self, vis_pred, topo_pred, vis_true, topo_true): # print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape) # print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item()) - vis_loss = self.mse(vis_pred, vis_true) - topo_loss = self.mse(topo_pred, topo_true) + vis_loss = self.loss(vis_pred, vis_true) + topo_loss = self.loss(topo_pred, topo_true) # print(vis_loss.item(), topo_loss.item()) return self.vision_weight * vis_loss + self.topo_weight * topo_loss diff --git a/minamo/model/topo.py b/minamo/model/topo.py index 57df620..07a0941 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -1,8 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.utils import spectral_norm -from torch_geometric.nn import global_mean_pool, TopKPooling, GATConv +from torch_geometric.nn import GATConv, AttentionalAggregation, global_max_pool from torch_geometric.data import Data class MinamoTopoModel(nn.Module): @@ -11,23 +10,19 @@ class MinamoTopoModel(nn.Module): ): super().__init__() # 传入 softmax 概率值,直接映射 - self.input_proj = torch.nn.Linear(tile_types, emb_dim) + self.input_proj = nn.Linear(tile_types, emb_dim) # 图卷积层 self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2) - self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4) - self.conv_ins2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4) - self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2) - self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False) + self.conv2 = GATConv(hidden_dim*16, hidden_dim*2, heads=8) + self.conv3 = GATConv(hidden_dim*16, hidden_dim*2, heads=8) + self.conv4 = GATConv(hidden_dim*16, out_dim, heads=1) # 正则化 self.norm1 = nn.LayerNorm(hidden_dim*16) self.norm2 = nn.LayerNorm(hidden_dim*16) - self.norm_ins2 = nn.LayerNorm(hidden_dim*16) - self.norm_ins1 = nn.LayerNorm(hidden_dim*16) - self.norm3 = nn.LayerNorm(out_dim) + self.norm3 = nn.LayerNorm(hidden_dim*16) + self.norm4 = nn.LayerNorm(out_dim) - # 池化层 - self.pool = TopKPooling(out_dim, ratio=0.8) # 保留80%关键节点 self.drop = nn.Dropout(0.3) # 增强MLP @@ -37,30 +32,25 @@ class MinamoTopoModel(nn.Module): def forward(self, graph: Data): x = self.input_proj(graph.x) - # identity = x x = self.conv1(x, graph.edge_index) - x = F.elu(self.norm1(x)) + x = F.relu(self.norm1(x)) x = self.conv2(x, graph.edge_index) - x = F.elu(self.norm2(x)) - - x = self.conv_ins2(x, graph.edge_index) - x = F.elu(self.norm_ins2(x)) - - x = self.conv_ins1(x, graph.edge_index) - x = F.elu(self.norm_ins1(x)) + x = F.relu(self.norm2(x)) x = self.conv3(x, graph.edge_index) - x = F.elu(self.norm3(x)) + x = F.relu(self.norm3(x)) - # 分层池化 + x = self.conv4(x, graph.edge_index) + x = F.relu(self.norm4(x)) + + # 池化 x = self.drop(x) - # x, _, _, batch, _, _ = self.pool(x, graph.edge_index, batch=graph.batch) - x = global_mean_pool(x, graph.batch) + x = global_max_pool(x, graph.batch) topo_vec = self.fc(x) - # 增强MLP + # 归一化 return F.normalize(topo_vec, p=2, dim=-1) \ No newline at end of file diff --git a/minamo/model/vision.py b/minamo/model/vision.py index 20953b2..415b272 100644 --- a/minamo/model/vision.py +++ b/minamo/model/vision.py @@ -1,57 +1,14 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.utils import spectral_norm -from shared.attention import CBAM +from torchvision.models import resnet18 class MinamoVisionModel(nn.Module): - def __init__(self, tile_types=32, conv_ch=64, out_dim=512): + def __init__(self, tile_types=32, out_dim=512): super().__init__() - # 输入 softmax 概率值 - self.input_conv = nn.Conv2d(tile_types, conv_ch, 3, padding=1) - - # 卷积部分 - self.vision_conv = nn.Sequential( - nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1), - nn.BatchNorm2d(conv_ch*2), - CBAM(conv_ch*2), - nn.GELU(), - nn.MaxPool2d(2), - nn.Dropout2d(0.4), - - nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1), - nn.BatchNorm2d(conv_ch*4), - CBAM(conv_ch*4), - nn.GELU(), - nn.MaxPool2d(2), - nn.Dropout2d(0.4), - - nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1), - nn.BatchNorm2d(conv_ch*8), - CBAM(conv_ch*8), - nn.GELU(), - # nn.MaxPool2d(2), - # nn.Dropout2d(0.4), - - nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1), - nn.BatchNorm2d(conv_ch*8), - CBAM(conv_ch*8), - nn.GELU(), - - nn.AdaptiveMaxPool2d(1) - ) - - # 输出为向量 - self.vision_head = nn.Sequential( - nn.Dropout(0.4), - nn.Linear(conv_ch*8, out_dim) - ) - - def forward(self, map): - x = self.input_conv(map) - x = self.vision_conv(x) - x = x.view(x.size(0), -1) # 展平 - - vision_vec = self.vision_head(x) + self.resnet = resnet18(num_classes=out_dim) + self.resnet.conv1 = nn.Conv2d(tile_types, 64, kernel_size=7, stride=2, padding=3, bias=False) + def forward(self, x): + vision_vec = self.resnet(x) return F.normalize(vision_vec, p=2, dim=-1) # 归一化 diff --git a/requirements.txt b/requirements.txt index f09e17c..aa7eed8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ torchvision torchaudio tqdm torch-geometric -transformers \ No newline at end of file +transformers +torch-scatter \ No newline at end of file diff --git a/shared/attention.py b/shared/attention.py index c6232d5..c77cc69 100644 --- a/shared/attention.py +++ b/shared/attention.py @@ -54,3 +54,20 @@ class CBAM(nn.Module): # 空间注意力 s_att = self.spatial_att(x) return x * s_att + +class SEBlock(nn.Module): + def __init__(self, channel, reduction=4): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.GELU(), + nn.Linear(channel // reduction, channel), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y \ No newline at end of file diff --git a/shared/graph.py b/shared/graph.py index 48dcb75..bfefa8c 100644 --- a/shared/graph.py +++ b/shared/graph.py @@ -1,24 +1,23 @@ import torch from torch_geometric.data import Data, Batch + def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data: """ - 可导的图结构转换(返回PyG Data对象) + 可导的图结构转换(返回 PyG Data 对象) map_probs: [C, H, W] 返回: - Data(x=[N,C], edge_index=[2,E], edge_attr=[E,C]) + Data(x=[N, C], edge_index=[2, E], edge_attr=[E, C]) """ C, H, W = map_probs.shape device = map_probs.device N = H * W - # 1. 节点特征(保留所有节点) + # 1. 节点特征 node_features = map_probs.view(C, -1).T # [N, C] - # 2. 构建所有可能的边连接(预计算) - # 生成坐标网格 - rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij') - node_indices = rows * W + cols + # 2. 构建所有可能的边连接 + node_indices = torch.arange(N, device=device).view(H, W) # 水平连接(右邻居) right_src = node_indices[:, :-1].flatten() @@ -28,20 +27,23 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data: down_src = node_indices[:-1, :].flatten() down_dst = node_indices[1:, :].flatten() - # 合并边列表 - edge_src = torch.cat([right_src, down_src]).to(device) - edge_dst = torch.cat([right_dst, down_dst]).to(device) - edge_index = torch.stack([edge_src, edge_dst]) # [2, E] + # 合并边列表(双向) + edge_src = torch.cat([right_src, down_src]) + edge_dst = torch.cat([right_dst, down_dst]) + edge_index = torch.cat([ + torch.stack([edge_src, edge_dst], dim=0), + torch.stack([edge_dst, edge_src], dim=0) # 反向连接 + ], dim=1).to(device, dtype=torch.long) - # 3. 计算可导的边权重(排除墙类型) - wall_class_idx = 1 # 假设类型1是墙 - src_probs = 1.0 - map_probs[wall_class_idx].flatten()[edge_src] # [E] - dst_probs = 1.0 - map_probs[wall_class_idx].flatten()[edge_dst] # [E] - edge_mask = (src_probs * dst_probs).unsqueeze(1) # [E, 1] + # 3. 计算可导的边权重 + wall_class_idx = 1 # 假设类别 1 是墙 + src_probs = torch.sigmoid(-map_probs[wall_class_idx].flatten()[edge_src]) + dst_probs = torch.sigmoid(-map_probs[wall_class_idx].flatten()[edge_dst]) + edge_mask = torch.nn.functional.softplus(src_probs * dst_probs).unsqueeze(1) # [E, 1] - # 4. 边特征计算(保持可导) - src_feat = map_probs[:, edge_src//W, edge_src%W].T # [E, C] - dst_feat = map_probs[:, edge_dst//W, edge_dst%W].T # [E, C] + # 4. 计算边特征 + src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C] + dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C] edge_attr = (src_feat + dst_feat) / 2 * edge_mask # [E, C] return Data(