mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 03:51:11 +08:00
refactor: GINKA 生成器使用 Minamo 作为损失值的一部分
This commit is contained in:
parent
1566acf691
commit
09c63fedce
@ -15,10 +15,11 @@ def load_data(path: str):
|
|||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
class GinkaDataset(Dataset):
|
class GinkaDataset(Dataset):
|
||||||
def __init__(self, data_path: str, minamo: MinamoModel):
|
def __init__(self, data_path: str, device, minamo: MinamoModel):
|
||||||
self.data = load_data(data_path) # 自定义数据加载函数
|
self.data = load_data(data_path) # 自定义数据加载函数
|
||||||
self.max_size = 32
|
self.max_size = 32
|
||||||
self.minamo = minamo
|
self.minamo = minamo
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -26,13 +27,13 @@ class GinkaDataset(Dataset):
|
|||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
item = self.data[idx]
|
item = self.data[idx]
|
||||||
|
|
||||||
target = torch.tensor(item["map"])
|
target = torch.tensor(item["map"]).to(self.device)
|
||||||
graph = convert_map_to_graph(target)
|
graph = convert_map_to_graph(target).to(self.device)
|
||||||
vision_feat, topo_feat = self.minamo(target, graph)
|
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
||||||
feat_vec = torch.cat([vision_feat, topo_feat])
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"feat_vec": feat_vec,
|
"target_vision_feat": vision_feat,
|
||||||
|
"target_topo_feat": topo_feat,
|
||||||
"target": target
|
"target": target
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3,6 +3,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from minamo.model.model import MinamoModel
|
from minamo.model.model import MinamoModel
|
||||||
|
from shared.graph import DynamicGraphConverter
|
||||||
|
|
||||||
def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]):
|
def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]):
|
||||||
"""地图最外层是否为墙"""
|
"""地图最外层是否为墙"""
|
||||||
@ -131,7 +132,6 @@ def entrance_distance_and_presence_loss(
|
|||||||
total_loss: 综合入口距离与存在性损失
|
total_loss: 综合入口距离与存在性损失
|
||||||
"""
|
"""
|
||||||
# 将 logits 转换为概率分布
|
# 将 logits 转换为概率分布
|
||||||
probs = F.softmax(logits, dim=1) # [B, C, H, W]
|
|
||||||
B, C, H, W = logits.shape
|
B, C, H, W = logits.shape
|
||||||
|
|
||||||
# 提取箭头和楼梯的概率图
|
# 提取箭头和楼梯的概率图
|
||||||
@ -147,9 +147,9 @@ def entrance_distance_and_presence_loss(
|
|||||||
arrow_distance_loss = F.relu(arrow_excess).mean()
|
arrow_distance_loss = F.relu(arrow_excess).mean()
|
||||||
|
|
||||||
# 楼梯:使用窗口大小为 (W//2, H//2)
|
# 楼梯:使用窗口大小为 (W//2, H//2)
|
||||||
kernel_size_stairs = (max(1, W // 2), max(1, H // 2))
|
kernel_size_stairs = (9, 9)
|
||||||
kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=logits.device)
|
kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=logits.device)
|
||||||
pad_stairs = (kernel_size_stairs[0] // 2, kernel_size_stairs[1] // 2)
|
pad_stairs = ((kernel_size_stairs[0] - 1) // 2, (kernel_size_stairs[1] - 1) // 2)
|
||||||
local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs)
|
local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs)
|
||||||
stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1)
|
stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1)
|
||||||
stairs_distance_loss = F.relu(stairs_excess).mean()
|
stairs_distance_loss = F.relu(stairs_excess).mean()
|
||||||
@ -283,7 +283,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler
|
|||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
class GinkaLoss(nn.Module):
|
class GinkaLoss(nn.Module):
|
||||||
def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
|
def __init__(self, minamo: MinamoModel, converter: DynamicGraphConverter, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
|
||||||
"""Ginka Model 损失函数部分
|
"""Ginka Model 损失函数部分
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -301,8 +301,10 @@ class GinkaLoss(nn.Module):
|
|||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.ce = nn.CrossEntropyLoss()
|
self.ce = nn.CrossEntropyLoss()
|
||||||
self.minamo = minamo
|
self.minamo = minamo
|
||||||
|
self.tau = 1
|
||||||
|
self.converter = converter
|
||||||
|
|
||||||
def forward(self, pred, pred_softmax, target):
|
def forward(self, pred, pred_softmax, target, target_vision_feat, target_topo_feat):
|
||||||
probs = F.softmax(pred, dim=1)
|
probs = F.softmax(pred, dim=1)
|
||||||
# 地图结构损失
|
# 地图结构损失
|
||||||
border_loss = wall_border_loss(pred, probs)
|
border_loss = wall_border_loss(pred, probs)
|
||||||
@ -314,21 +316,27 @@ class GinkaLoss(nn.Module):
|
|||||||
count_loss = integrated_count_loss(probs, target)
|
count_loss = integrated_count_loss(probs, target)
|
||||||
|
|
||||||
# 使用 Minamo Model 计算相似度
|
# 使用 Minamo Model 计算相似度
|
||||||
|
graph = self.converter(pred, tau=self.tau)
|
||||||
|
pred_vision_feat, pred_topo_feat = self.minamo(pred_softmax, graph)
|
||||||
|
|
||||||
|
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
|
||||||
|
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
|
||||||
|
minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
|
||||||
|
minamo_loss = torch.exp(-10 * (minamo_sim - 0.8)).mean()
|
||||||
|
|
||||||
print(
|
# print(
|
||||||
# structure_loss.item(),
|
# minamo_loss.item(),
|
||||||
border_loss.item(),
|
# border_loss.item(),
|
||||||
wall_loss.item(),
|
# wall_loss.item(),
|
||||||
entry_loss.item(),
|
# entry_loss.item(),
|
||||||
entry_dis_loss.item(),
|
# entry_dis_loss.item(),
|
||||||
enemy_loss.item(),
|
# enemy_loss.item(),
|
||||||
valid_block_loss.item(),
|
# valid_block_loss.item(),
|
||||||
count_loss.item()
|
# count_loss.item()
|
||||||
)
|
# )
|
||||||
|
|
||||||
return (
|
return (
|
||||||
# structure_loss * self.weight[0] +
|
minamo_loss * self.weight[0] +
|
||||||
border_loss * self.weight[1] +
|
border_loss * self.weight[1] +
|
||||||
wall_loss * self.weight[2] +
|
wall_loss * self.weight[2] +
|
||||||
entry_loss * self.weight[3] +
|
entry_loss * self.weight[3] +
|
||||||
|
|||||||
@ -14,8 +14,8 @@ class GumbelSoftmax(nn.Module):
|
|||||||
y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
|
y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
|
||||||
|
|
||||||
# 转换为类索引的连续表示
|
# 转换为类索引的连续表示
|
||||||
class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
|
# class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
|
||||||
return (y * class_indices).sum(dim=1) # 形状[BS, H, W]
|
return y.argmax(dim=1) # 形状[BS, H, W]
|
||||||
|
|
||||||
class GinkaModel(nn.Module):
|
class GinkaModel(nn.Module):
|
||||||
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
|
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
|
||||||
|
|||||||
@ -16,33 +16,25 @@ class GinkaEncoder(nn.Module):
|
|||||||
self.pool = nn.MaxPool2d(2)
|
self.pool = nn.MaxPool2d(2)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_res = self.conv(x) # 卷积提取特征
|
x_res = self.conv(x)
|
||||||
x_down = self.pool(x_res) # 进行池化
|
x_down = self.pool(x_res)
|
||||||
return x_down, x_res # 返回池化后的特征和跳跃连接特征
|
return x_down, x_res
|
||||||
|
|
||||||
class GinkaDecoder(nn.Module):
|
class GinkaDecoder(nn.Module):
|
||||||
"""解码器(上采样)部分"""
|
"""解码器(上采样)部分"""
|
||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 上采样(双线性插值 + 卷积)
|
self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||||
self.upsample = nn.Sequential(
|
|
||||||
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
|
|
||||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
|
||||||
nn.BatchNorm2d(out_channels),
|
|
||||||
nn.ReLU()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 跳跃连接融合
|
|
||||||
self.conv = nn.Sequential(
|
self.conv = nn.Sequential(
|
||||||
nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
|
nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1),
|
||||||
nn.BatchNorm2d(out_channels),
|
nn.BatchNorm2d(out_channels),
|
||||||
nn.ReLU()
|
nn.ReLU()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, skip):
|
def forward(self, x, skip):
|
||||||
x = self.upsample(x)
|
x = self.upsample(x)
|
||||||
# 跳跃连接融合
|
x = torch.cat([x, skip], dim=1)
|
||||||
x = torch.cat([x, skip], dim=1)
|
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -66,26 +58,26 @@ class GinkaUNet(nn.Module):
|
|||||||
"""Ginka Model UNet 部分
|
"""Ginka Model UNet 部分
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.down1 = GinkaEncoder(in_ch, in_ch*2)
|
self.down1 = GinkaEncoder(in_ch, in_ch*2)
|
||||||
self.down2 = GinkaEncoder(in_ch*2, in_ch*4)
|
self.down2 = GinkaEncoder(in_ch*2, in_ch*4)
|
||||||
|
|
||||||
self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4)
|
self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4)
|
||||||
|
|
||||||
self.up1 = GinkaDecoder(in_ch*4, in_ch*2)
|
self.up1 = GinkaDecoder(in_ch*4, in_ch*2)
|
||||||
self.up2 = GinkaDecoder(in_ch*2, in_ch)
|
self.up2 = GinkaDecoder(in_ch*2, in_ch)
|
||||||
|
|
||||||
self.final = nn.Sequential(
|
self.final = nn.Sequential(
|
||||||
nn.Conv2d(in_ch, out_ch, 1)
|
nn.Conv2d(in_ch, out_ch, 1),
|
||||||
|
# nn.Softmax(dim=1) # 适用于分类任务
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x, skip1 = self.down1(x)
|
x_down1, skip1 = self.down1(x)
|
||||||
x, skip2 = self.down2(x)
|
x_down2, skip2 = self.down2(x_down1)
|
||||||
|
|
||||||
x = self.bottleneck(x)
|
x = self.bottleneck(x_down2)
|
||||||
|
|
||||||
x = self.up1(x, skip2)
|
x = self.up1(x, skip2) # 用 down2 的 skip
|
||||||
x = self.up2(x, skip1)
|
x = self.up2(x, skip1) # 用 down1 的 skip
|
||||||
|
|
||||||
return self.final(x)
|
return self.final(x)
|
||||||
|
|||||||
@ -2,17 +2,17 @@ import os
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import torch
|
import torch
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from transformers import BertTokenizer
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .model.model import GinkaModel
|
from .model.model import GinkaModel
|
||||||
from .model.loss import GinkaLoss
|
from .model.loss import GinkaLoss
|
||||||
from .dataset import GinkaDataset
|
from .dataset import GinkaDataset
|
||||||
from minamo.model.model import MinamoModel
|
from minamo.model.model import MinamoModel
|
||||||
|
from shared.graph import DynamicGraphConverter
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
|
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
||||||
|
|
||||||
epochs = 70
|
epochs = 70
|
||||||
|
|
||||||
@ -29,48 +29,85 @@ def train():
|
|||||||
minamo = MinamoModel(32)
|
minamo = MinamoModel(32)
|
||||||
minamo.to(device)
|
minamo.to(device)
|
||||||
minamo.eval()
|
minamo.eval()
|
||||||
|
|
||||||
|
converter = DynamicGraphConverter().to(device)
|
||||||
|
|
||||||
# 准备数据集
|
# 准备数据集
|
||||||
dataset = GinkaDataset("dataset.json", minamo)
|
dataset = GinkaDataset("ginka-dataset.json", device, minamo)
|
||||||
|
dataset_val = GinkaDataset("ginka-eval.json", device, minamo)
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
shuffle=True
|
shuffle=True
|
||||||
)
|
)
|
||||||
|
dataloader_val = DataLoader(
|
||||||
|
dataset_val,
|
||||||
|
batch_size=32,
|
||||||
|
shuffle=True
|
||||||
|
)
|
||||||
|
|
||||||
# 设定优化器与调度器
|
# 设定优化器与调度器
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
|
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
criterion = GinkaLoss(minamo)
|
criterion = GinkaLoss(minamo, converter)
|
||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
for epoch in tqdm(range(epochs)):
|
for epoch in tqdm(range(epochs)):
|
||||||
model.train()
|
model.train()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
model.softmax.tau = update_tau(epoch)
|
model.softmax.tau = update_tau(epoch)
|
||||||
|
criterion.tau = update_tau(epoch)
|
||||||
|
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
# 数据迁移到设备
|
# 数据迁移到设备
|
||||||
target = batch["target"].to(device)
|
target = batch["target"].to(device)
|
||||||
feat_vec = batch["feat_vec"].to(device)
|
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||||
|
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||||
|
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||||
# 前向传播
|
# 前向传播
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
output, output_softmax = model(feat_vec)
|
output, output_softmax = model(feat_vec)
|
||||||
|
|
||||||
# 计算损失
|
# 计算损失
|
||||||
loss = criterion(output, output_softmax, target)
|
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)
|
||||||
|
|
||||||
# 反向传播
|
# 反向传播
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|
||||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {total_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
avg_loss = total_loss / len(dataloader)
|
||||||
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||||
|
|
||||||
# 学习率调整
|
# 学习率调整
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
|
if (epoch + 1) % 5 == 0:
|
||||||
|
loss_val = 0
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in dataloader_val:
|
||||||
|
# 数据迁移到设备
|
||||||
|
target = batch["target"].to(device)
|
||||||
|
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||||
|
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||||
|
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
output, output_softmax = model(feat_vec)
|
||||||
|
|
||||||
|
# 计算损失
|
||||||
|
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)
|
||||||
|
loss_val += loss.item()
|
||||||
|
|
||||||
|
avg_val_loss = loss_val / len(dataloader_val)
|
||||||
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||||
|
torch.save({
|
||||||
|
"model_state": model.state_dict(),
|
||||||
|
"optimizer_state": optimizer.state_dict(),
|
||||||
|
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
|
||||||
|
|
||||||
|
|
||||||
print("Train ended.")
|
print("Train ended.")
|
||||||
|
|
||||||
torch.save({
|
torch.save({
|
||||||
|
|||||||
126
shared/graph.py
126
shared/graph.py
@ -1,5 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch_geometric.data import Data
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch_geometric.data import Data, Batch
|
||||||
|
from torch_geometric.utils import dense_to_sparse
|
||||||
|
|
||||||
def convert_map_to_graph(map):
|
def convert_map_to_graph(map):
|
||||||
rows = len(map)
|
rows = len(map)
|
||||||
@ -26,4 +29,123 @@ def convert_map_to_graph(map):
|
|||||||
edge_index = torch.tensor(edge_list, dtype=torch.long).T
|
edge_index = torch.tensor(edge_list, dtype=torch.long).T
|
||||||
node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)
|
node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)
|
||||||
|
|
||||||
return Data(x=node_features, edge_index=edge_index)
|
return Data(x=node_features, edge_index=edge_index)
|
||||||
|
|
||||||
|
def soft_convert_map_to_graph(map_tensor, tau=0.5, threshold=0.1):
    """Convert a batch of per-cell class scores into a weighted grid graph.

    Each cell of the H x W map becomes one node.  A directed edge is created
    from every cell to its right neighbour and to its lower neighbour, weighted
    by ``(1 - wall[src]) * (1 - wall[dst])`` where ``wall`` is the one-hot
    indicator of class 1 taken from a hard Gumbel-Softmax sample.  Only edges
    whose weight is strictly greater than ``threshold`` are kept.

    Args:
        map_tensor: Tensor unpacked as ``[B, H, W, C]``.  It is passed straight
            to ``F.gumbel_softmax``, which treats the last dimension as
            unnormalised log-probabilities.
        tau: Gumbel-Softmax temperature.
        threshold: Minimum (exclusive) edge weight for an edge to be kept.

    Returns:
        A single ``Data`` object with ``x`` of shape ``[B, N, 1]`` (integer
        class ids, ``N = H * W``), ``edge_index`` of shape ``[2, E]`` and
        ``edge_attr`` of shape ``[E]``, where ``E`` is the total number of kept
        edges over the whole batch.

    NOTE(review): the per-sample edge indices are concatenated without adding
    a ``b * N`` offset, so every sample's edges refer to node ids ``0..N-1``.
    This matches the original behaviour; confirm with the consumer whether a
    ``Batch`` with per-sample offsets is what is actually wanted.
    """
    B, H, W, C = map_tensor.shape
    N = H * W
    device = map_tensor.device

    # reshape (not view) so that non-contiguous inputs, e.g. a permuted
    # [B, C, H, W] tensor, are accepted as well.
    y = F.gumbel_softmax(map_tensor.reshape(B, N, C), tau=tau, hard=True)  # [B, N, C]

    # Integer class index per node (non-differentiable), e.g. for an embedding.
    node_features = y.argmax(dim=-1).long().unsqueeze(-1)  # [B, N, 1]

    # One-hot wall indicator taken from the sampled tensor; class 1 is
    # treated as "wall".
    wall_mask = y[:, :, 1]  # [B, N]
    open_grid = (1 - wall_mask).reshape(B, H, W)
    cell_ids = torch.arange(N, device=device).reshape(H, W)  # row-major node ids

    adjacency_matrix = torch.zeros(B, N, N, device=device)

    # Fill all "cell -> right neighbour" entries in one assignment.
    if W > 1:
        src = cell_ids[:, :-1].reshape(-1)
        dst = cell_ids[:, 1:].reshape(-1)
        adjacency_matrix[:, src, dst] = (
            open_grid[:, :, :-1] * open_grid[:, :, 1:]
        ).reshape(B, -1)

    # Fill all "cell -> lower neighbour" entries in one assignment.
    if H > 1:
        src = cell_ids[:-1, :].reshape(-1)
        dst = cell_ids[1:, :].reshape(-1)
        adjacency_matrix[:, src, dst] = (
            open_grid[:, :-1, :] * open_grid[:, 1:, :]
        ).reshape(B, -1)

    edge_index_list, edge_weight_list = [], []
    for b in range(B):
        adjacency = adjacency_matrix[b]
        # nonzero() walks the matrix row-major, so edges come out sorted by
        # (source, destination).
        edge_index = torch.nonzero(adjacency > threshold, as_tuple=False).T  # [2, E_b]
        edge_weight = adjacency[edge_index[0], edge_index[1]]

        edge_index_list.append(edge_index)
        edge_weight_list.append(edge_weight)

    edge_index = torch.cat(edge_index_list, dim=1)
    edge_weight = torch.cat(edge_weight_list, dim=0)

    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_weight)
|
||||||
|
|
||||||
|
class DynamicGraphConverter(nn.Module):
    """Turn a batch of per-cell class scores into a batch of grid graphs.

    Every map cell is a node whose feature is its sampled class id.  Candidate
    edges connect each cell to its right and lower neighbour (one direction
    only, no diagonals); edges touching a wall cell (class 1) are dropped.

    The node ids come from ``argmax`` and the edge weights from an equality
    test on those ids, so no gradient flows through the produced graph.

    Args:
        map_size: Side length of the square map the edge template is
            precomputed for.  Inputs of another size are still handled; their
            edge template is built on the fly.
    """

    def __init__(self, map_size=13):
        super().__init__()
        self.map_size = map_size
        self.n_nodes = map_size * map_size

        # Registered as a buffer so ``module.to(device)`` moves it together
        # with the module; non-persistent so the state_dict stays unchanged.
        self.register_buffer(
            "base_edge_index", self._precompute_base_edges(), persistent=False
        )

    @staticmethod
    def _grid_edges(rows, cols):
        """Return the ``[2, E]`` right/down neighbour edges of a rows x cols grid."""
        edge_list = []
        for r in range(rows):
            for c in range(cols):
                node = r * cols + c  # row-major node id
                if c + 1 < cols:
                    edge_list.append((node, node + 1))  # right neighbour
                if r + 1 < rows:
                    edge_list.append((node, node + cols))  # lower neighbour

        if not edge_list:
            # A 1x1 (or empty) grid has no edges; keep the [2, 0] layout.
            return torch.empty((2, 0), dtype=torch.long)

        # Edges are generated sorted by (source, destination) and without
        # duplicates, so no extra de-duplication pass is required.
        return torch.tensor(edge_list, dtype=torch.long).t().contiguous()

    def _precompute_base_edges(self):
        """Build the candidate edge index for a ``map_size`` x ``map_size`` grid."""
        return self._grid_edges(self.map_size, self.map_size)

    def forward(self, map_probs, tau=0.5):
        """Sample node classes and build one graph per batch element.

        Args:
            map_probs: Tensor unpacked as ``[B, C, H, W]``.  It is fed to
                ``F.gumbel_softmax``, which interprets the class dimension as
                unnormalised log-probabilities.
            tau: Gumbel-Softmax temperature.

        Returns:
            ``Batch`` built from one ``Data`` per sample, each with ``x`` of
            shape ``[H * W]`` (class ids), ``edge_index`` of shape ``[2, E_b]``
            and ``edge_attr`` of shape ``[E_b, 1]``.
        """
        B, C, H, W = map_probs.shape
        device = map_probs.device

        if (H, W) == (self.map_size, self.map_size):
            edge_index = self.base_edge_index.to(device)
        else:
            # The cached template only fits map_size x map_size grids; using
            # it for another size would index out of range or connect the
            # wrong cells, so build a matching template instead.
            edge_index = self._grid_edges(H, W).to(device)

        # 1. Discretise the node classes.  flatten() also accepts
        #    non-contiguous inputs, unlike view().
        node_logits = map_probs.flatten(2).permute(0, 2, 1)  # [B, N, C]
        hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
        node_ids = hard_nodes.argmax(dim=-1)  # [B, N]

        # 2. Edge weights derived from wall cells (class 1 is assumed to be wall).
        wall_mask = (node_ids == 1).float()
        edge_weights = self._compute_dynamic_weights(wall_mask, edge_index)

        # 3. Build one graph per sample, keeping only edges not touching a wall.
        batch_data = []
        for b in range(B):
            valid_mask = (edge_weights[b] > 0.1).squeeze(-1)
            batch_data.append(
                Data(
                    x=node_ids[b],
                    edge_index=edge_index[:, valid_mask],
                    edge_attr=edge_weights[b][valid_mask],
                )
            )

        return Batch.from_data_list(batch_data)

    def _compute_dynamic_weights(self, wall_mask, edge_index=None):
        """Compute per-edge weights: 0 if either endpoint is a wall, else 1.

        Args:
            wall_mask: ``[B, N]`` tensor, non-zero where the node is a wall.
            edge_index: Optional ``[2, E]`` edge template; defaults to the
                precomputed ``base_edge_index``.

        Returns:
            Float tensor of shape ``[B, E, 1]``.
        """
        if edge_index is None:
            edge_index = self.base_edge_index
        edge_index = edge_index.to(wall_mask.device)

        src_nodes = edge_index[0]  # [E]
        dst_nodes = edge_index[1]  # [E]

        # weight = 1 - (source is wall OR destination is wall)
        weights = 1 - torch.logical_or(
            wall_mask[:, src_nodes],
            wall_mask[:, dst_nodes],
        ).float()  # [B, E]

        return weights.unsqueeze(-1)  # [B, E, 1]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user