mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
refactor: GINKA 生成器使用 Minamo 作为损失值的一部分
This commit is contained in:
parent
1566acf691
commit
09c63fedce
@ -15,10 +15,11 @@ def load_data(path: str):
|
||||
return data_list
|
||||
|
||||
class GinkaDataset(Dataset):
    """Dataset pairing each ground-truth map with its Minamo feature targets.

    For every sample the (externally provided) Minamo model is run on the
    ground-truth map to obtain vision/topology feature vectors; these are
    returned as fixed targets for the generator's similarity loss.
    """

    def __init__(self, data_path: str, device, minamo: MinamoModel):
        # Custom data-loading helper defined earlier in this module.
        self.data = load_data(data_path)
        # NOTE(review): presumably a maximum map size — unused within this
        # class as shown; confirm against callers.
        self.max_size = 32
        self.minamo = minamo
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        target = torch.tensor(item["map"]).to(self.device)
        graph = convert_map_to_graph(target).to(self.device)
        # Minamo expects a leading batch dimension, hence unsqueeze(0).
        vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)

        return {
            "target_vision_feat": vision_feat,
            "target_topo_feat": topo_feat,
            "target": target
        }
|
||||
|
||||
@ -3,6 +3,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from minamo.model.model import MinamoModel
|
||||
from shared.graph import DynamicGraphConverter
|
||||
|
||||
def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]):
|
||||
"""地图最外层是否为墙"""
|
||||
@ -131,7 +132,6 @@ def entrance_distance_and_presence_loss(
|
||||
total_loss: 综合入口距离与存在性损失
|
||||
"""
|
||||
# 将 logits 转换为概率分布
|
||||
probs = F.softmax(logits, dim=1) # [B, C, H, W]
|
||||
B, C, H, W = logits.shape
|
||||
|
||||
# 提取箭头和楼梯的概率图
|
||||
@ -147,9 +147,9 @@ def entrance_distance_and_presence_loss(
|
||||
arrow_distance_loss = F.relu(arrow_excess).mean()
|
||||
|
||||
# 楼梯:使用窗口大小为 (W//2, H//2)
|
||||
kernel_size_stairs = (max(1, W // 2), max(1, H // 2))
|
||||
kernel_size_stairs = (9, 9)
|
||||
kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=logits.device)
|
||||
pad_stairs = (kernel_size_stairs[0] // 2, kernel_size_stairs[1] // 2)
|
||||
pad_stairs = ((kernel_size_stairs[0] - 1) // 2, (kernel_size_stairs[1] - 1) // 2)
|
||||
local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs)
|
||||
stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1)
|
||||
stairs_distance_loss = F.relu(stairs_excess).mean()
|
||||
@ -283,7 +283,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler
|
||||
return avg_loss
|
||||
|
||||
class GinkaLoss(nn.Module):
|
||||
def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
|
||||
def __init__(self, minamo: MinamoModel, converter: DynamicGraphConverter, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
|
||||
"""Ginka Model 损失函数部分
|
||||
|
||||
Args:
|
||||
@ -301,8 +301,10 @@ class GinkaLoss(nn.Module):
|
||||
self.weight = weight
|
||||
self.ce = nn.CrossEntropyLoss()
|
||||
self.minamo = minamo
|
||||
self.tau = 1
|
||||
self.converter = converter
|
||||
|
||||
def forward(self, pred, pred_softmax, target):
|
||||
def forward(self, pred, pred_softmax, target, target_vision_feat, target_topo_feat):
|
||||
probs = F.softmax(pred, dim=1)
|
||||
# 地图结构损失
|
||||
border_loss = wall_border_loss(pred, probs)
|
||||
@ -314,21 +316,27 @@ class GinkaLoss(nn.Module):
|
||||
count_loss = integrated_count_loss(probs, target)
|
||||
|
||||
# 使用 Minamo Model 计算相似度
|
||||
graph = self.converter(pred, tau=self.tau)
|
||||
pred_vision_feat, pred_topo_feat = self.minamo(pred_softmax, graph)
|
||||
|
||||
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
|
||||
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
|
||||
minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
|
||||
minamo_loss = torch.exp(-10 * (minamo_sim - 0.8)).mean()
|
||||
|
||||
print(
|
||||
# structure_loss.item(),
|
||||
border_loss.item(),
|
||||
wall_loss.item(),
|
||||
entry_loss.item(),
|
||||
entry_dis_loss.item(),
|
||||
enemy_loss.item(),
|
||||
valid_block_loss.item(),
|
||||
count_loss.item()
|
||||
)
|
||||
# print(
|
||||
# minamo_loss.item(),
|
||||
# border_loss.item(),
|
||||
# wall_loss.item(),
|
||||
# entry_loss.item(),
|
||||
# entry_dis_loss.item(),
|
||||
# enemy_loss.item(),
|
||||
# valid_block_loss.item(),
|
||||
# count_loss.item()
|
||||
# )
|
||||
|
||||
return (
|
||||
# structure_loss * self.weight[0] +
|
||||
minamo_loss * self.weight[0] +
|
||||
border_loss * self.weight[1] +
|
||||
wall_loss * self.weight[2] +
|
||||
entry_loss * self.weight[3] +
|
||||
|
||||
@ -14,8 +14,8 @@ class GumbelSoftmax(nn.Module):
|
||||
y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
|
||||
|
||||
# 转换为类索引的连续表示
|
||||
class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
|
||||
return (y * class_indices).sum(dim=1) # 形状[BS, H, W]
|
||||
# class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
|
||||
return y.argmax(dim=1) # 形状[BS, H, W]
|
||||
|
||||
class GinkaModel(nn.Module):
|
||||
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
|
||||
|
||||
@ -16,33 +16,25 @@ class GinkaEncoder(nn.Module):
|
||||
self.pool = nn.MaxPool2d(2)
|
||||
|
||||
def forward(self, x):
    """Run one encoder stage.

    Returns:
        (x_down, x_res): pooled features for the next stage, and the
        pre-pooling features kept as the skip connection.
    """
    x_res = self.conv(x)       # convolution extracts features
    x_down = self.pool(x_res)  # pooling halves the spatial size
    return x_down, x_res
||||
|
||||
class GinkaDecoder(nn.Module):
    """Decoder (upsampling) stage of the UNet."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Learned upsampling via transposed convolution (replaces the older
        # bilinear-interpolation + conv variant).
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        # Skip-connection fusion conv. After concatenation the channel count
        # is out_channels (upsampled x) + in_channels — assumes the skip
        # tensor carries in_channels channels; TODO confirm against encoder.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x, skip):
        x = self.upsample(x)
        # Skip-connection fusion along the channel axis.
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x
||||
|
||||
@ -66,26 +58,26 @@ class GinkaUNet(nn.Module):
|
||||
"""Ginka Model UNet 部分
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.down1 = GinkaEncoder(in_ch, in_ch*2)
|
||||
self.down2 = GinkaEncoder(in_ch*2, in_ch*4)
|
||||
|
||||
|
||||
self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4)
|
||||
|
||||
|
||||
self.up1 = GinkaDecoder(in_ch*4, in_ch*2)
|
||||
self.up2 = GinkaDecoder(in_ch*2, in_ch)
|
||||
|
||||
|
||||
self.final = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, 1)
|
||||
nn.Conv2d(in_ch, out_ch, 1),
|
||||
# nn.Softmax(dim=1) # 适用于分类任务
|
||||
)
|
||||
|
||||
def forward(self, x):
    """Standard UNet pass: two down stages, bottleneck, two up stages."""
    x_down1, skip1 = self.down1(x)
    x_down2, skip2 = self.down2(x_down1)

    x = self.bottleneck(x_down2)

    x = self.up1(x, skip2)  # fuse with down2's skip connection
    x = self.up2(x, skip1)  # fuse with down1's skip connection

    return self.final(x)
|
||||
|
||||
@ -2,17 +2,17 @@ import os
|
||||
from datetime import datetime
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import BertTokenizer
|
||||
from tqdm import tqdm
|
||||
from .model.model import GinkaModel
|
||||
from .model.loss import GinkaLoss
|
||||
from .dataset import GinkaDataset
|
||||
from minamo.model.model import MinamoModel
|
||||
from shared.graph import DynamicGraphConverter
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
||||
|
||||
epochs = 70
|
||||
|
||||
@ -29,48 +29,85 @@ def train():
|
||||
minamo = MinamoModel(32)
|
||||
minamo.to(device)
|
||||
minamo.eval()
|
||||
|
||||
converter = DynamicGraphConverter().to(device)
|
||||
|
||||
# 准备数据集
|
||||
dataset = GinkaDataset("dataset.json", minamo)
|
||||
dataset = GinkaDataset("ginka-dataset.json", device, minamo)
|
||||
dataset_val = GinkaDataset("ginka-eval.json", device, minamo)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
shuffle=True
|
||||
)
|
||||
dataloader_val = DataLoader(
|
||||
dataset_val,
|
||||
batch_size=32,
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion = GinkaLoss(minamo)
|
||||
criterion = GinkaLoss(minamo, converter)
|
||||
|
||||
# 开始训练
|
||||
for epoch in tqdm(range(epochs)):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
model.softmax.tau = update_tau(epoch)
|
||||
criterion.tau = update_tau(epoch)
|
||||
|
||||
for batch in dataloader:
|
||||
# 数据迁移到设备
|
||||
target = batch["target"].to(device)
|
||||
feat_vec = batch["feat_vec"].to(device)
|
||||
|
||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
output, output_softmax = model(feat_vec)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(output, output_softmax, target)
|
||||
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
|
||||
# 反向传播
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.item()
|
||||
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {total_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
|
||||
# 学习率调整
|
||||
scheduler.step()
|
||||
|
||||
if (epoch + 1) % 5 == 0:
|
||||
loss_val = 0
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in dataloader_val:
|
||||
# 数据迁移到设备
|
||||
target = batch["target"].to(device)
|
||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
|
||||
# 前向传播
|
||||
output, output_softmax = model(feat_vec)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
loss_val += loss.item()
|
||||
|
||||
avg_val_loss = loss_val / len(dataloader_val)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||
torch.save({
|
||||
"model_state": model.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
|
||||
|
||||
|
||||
print("Train ended.")
|
||||
|
||||
torch.save({
|
||||
|
||||
126
shared/graph.py
126
shared/graph.py
@ -1,5 +1,8 @@
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.data import Data, Batch
|
||||
from torch_geometric.utils import dense_to_sparse
|
||||
|
||||
def convert_map_to_graph(map):
|
||||
rows = len(map)
|
||||
@ -26,4 +29,123 @@ def convert_map_to_graph(map):
|
||||
edge_index = torch.tensor(edge_list, dtype=torch.long).T
|
||||
node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)
|
||||
|
||||
return Data(x=node_features, edge_index=edge_index)
|
||||
return Data(x=node_features, edge_index=edge_index)
|
||||
|
||||
def soft_convert_map_to_graph(map_tensor, tau=0.5, threshold=0.1):
    """
    Convert a batch of maps into GNN-usable graph Data using a soft strategy.

    Args:
        map_tensor: per-tile class logits, shape [B, H, W, C].
        tau: Gumbel-Softmax temperature.
        threshold: minimum soft-adjacency weight for an edge to be kept.

    Returns:
        A torch_geometric ``Data`` whose ``x`` is [B, N, 1] integer class ids
        and whose edges are concatenated across the batch.

    NOTE(review): per-sample edge indices are concatenated WITHOUT adding a
    ``b * N`` node offset, so edges from different batch items alias the same
    node ids while ``x`` keeps its batch dimension — verify consumers expect
    this layout (``torch_geometric.data.Batch`` would be the conventional fix).
    """
    B, H, W, C = map_tensor.shape
    N = H * W

    # Gumbel-Softmax with hard=True yields a one-hot sample per tile.
    y = F.gumbel_softmax(map_tensor.view(B, N, C), tau=tau, hard=True)  # [B, N, C]

    # Integer class index per node (for embedding lookups downstream).
    node_features = y.argmax(dim=-1).long().unsqueeze(-1)  # [B, N, 1]

    # Soft wall mask. Assumes channel 1 is the wall class — TODO confirm.
    wall_mask = y[:, :, 1]  # [B, N]

    adjacency_matrix = torch.zeros(B, N, N, device=map_tensor.device)

    def get_index(r, c):
        # Row-major flattening of grid coordinates.
        return r * W + c

    # Connect each cell to its right and down neighbours; the weight is high
    # only when neither endpoint is (softly) a wall.
    for r in range(H):
        for c in range(W):
            idx = get_index(r, c)
            if c + 1 < W:
                right_idx = get_index(r, c + 1)
                adjacency_matrix[:, idx, right_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, right_idx])
            if r + 1 < H:
                down_idx = get_index(r + 1, c)
                adjacency_matrix[:, idx, down_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, down_idx])

    edge_index_list, edge_weight_list = [], []
    for b in range(B):
        # Keep only edges whose soft weight clears the threshold.
        adj_bin = (adjacency_matrix[b] > threshold).float()
        edge_index = torch.nonzero(adj_bin, as_tuple=False).T  # [2, E]
        edge_weight = adjacency_matrix[b][edge_index[0], edge_index[1]]

        edge_index_list.append(edge_index)
        edge_weight_list.append(edge_weight)

    edge_index = torch.cat(edge_index_list, dim=1)
    edge_weight = torch.cat(edge_weight_list, dim=0)

    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_weight)
|
||||
|
||||
class DynamicGraphConverter(nn.Module):
    """Converts generated map probability grids into torch_geometric graphs.

    Grid adjacency (right/down neighbours) is precomputed once for a fixed
    ``map_size`` and then dynamically masked per sample, dropping edges that
    touch a sampled wall tile.
    """

    def __init__(self, map_size=13):
        super().__init__()
        self.map_size = map_size
        self.n_nodes = map_size * map_size

        # Precompute all candidate edge indices once (right/down neighbours;
        # the original comment says "including diagonals" but only two
        # orthogonal directions are generated below).
        self.base_edge_index = self._precompute_base_edges()

    def _precompute_base_edges(self):
        """Pre-generate the grid's candidate edge index (all possible adjacencies)."""
        edge_list = []
        directions = [
            (0, 1),  # right
            (1, 0),  # down
        ]

        for r in range(self.map_size):
            for c in range(self.map_size):
                node = r * self.map_size + c
                for dr, dc in directions:
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < self.map_size and 0 <= nc < self.map_size:
                        neighbor = nr * self.map_size + nc
                        edge_list.append([node, neighbor])

        # [2, E]; unique(dim=1) removes any duplicate columns.
        return torch.tensor(edge_list).t().contiguous().unique(dim=1)

    def forward(self, map_probs, tau=0.5):
        """Build a batched dynamic graph from map logits.

        Args:
            map_probs: [B, C, H, W] per-tile class logits.
            tau: Gumbel-Softmax temperature.

        Returns:
            A torch_geometric ``Batch`` with one graph per sample.
        """
        B, C, H, W = map_probs.shape
        device = map_probs.device

        # Keep the cached edge index on the input's device (mutates the
        # attribute on first cross-device call).
        self.base_edge_index = self.base_edge_index.to(device)

        # 1. Discretize node features.
        # NOTE(review): the original comment claims this stays differentiable,
        # but argmax after gumbel_softmax breaks the gradient path — node_ids
        # carries no gradient; confirm this is intended.
        node_logits = map_probs.view(B, C, -1).permute(0, 2, 1)  # [B, N, C]
        hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
        node_ids = hard_nodes.argmax(dim=-1)  # [B, N]

        # 2. Dynamic edge weights. Assumes class 1 is the wall tile — TODO confirm.
        wall_mask = (node_ids == 1).float()
        edge_weights = self._compute_dynamic_weights(wall_mask)

        # 3. Build one graph per sample and batch them.
        batch_data = []
        for b in range(B):
            # Drop invalid edges (those touching a wall have weight 0,
            # so the 0.1 threshold keeps exactly the weight-1 edges).
            valid_mask = (edge_weights[b] > 0.1).squeeze(-1)
            dynamic_edge_index = self.base_edge_index[:, valid_mask]
            dynamic_edge_attr = edge_weights[b][valid_mask]

            data = Data(
                x=node_ids[b],
                edge_index=dynamic_edge_index,
                edge_attr=dynamic_edge_attr
            )
            batch_data.append(data)

        return Batch.from_data_list(batch_data)

    def _compute_dynamic_weights(self, wall_mask):
        """Edge weight = 1 unless either endpoint is a wall (then 0)."""
        # wall_mask: [B, N]
        src_nodes = self.base_edge_index[0]  # [E]
        dst_nodes = self.base_edge_index[1]  # [E]

        # weight = 1 - (source is wall OR destination is wall)
        weights = 1 - torch.logical_or(
            wall_mask[:, src_nodes],
            wall_mask[:, dst_nodes]
        ).float()  # [B, E]

        return weights.unsqueeze(-1)  # [B, E, 1]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user