feat: 图卷积部分换用 TransformerConv

This commit is contained in:
unanmed 2025-05-09 13:07:03 +08:00
parent 7b138c66d9
commit 21b693ec21
6 changed files with 127 additions and 54 deletions

View File

@ -1,9 +1,18 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GCNConv, TransformerConv
from torch_geometric.utils import grid
def batch_edge_index(B, edge_index, num_nodes_per_batch):
    """Replicate a single-graph edge index across a batch of B graphs.

    Each copy of ``edge_index`` is shifted by ``i * num_nodes_per_batch`` so
    the node ids of batch element ``i`` do not collide with any other
    element, matching PyG's flat (disjoint-union) batching convention.

    Args:
        B: number of graphs in the batch.
        edge_index: ``[2, E]`` tensor of edges for one graph.
        num_nodes_per_batch: number of nodes in one graph (the per-batch id offset).

    Returns:
        ``[2, B * E]`` tensor with all shifted copies concatenated.
    """
    # torch.cat([]) would raise on an empty batch; return an empty edge set.
    if B == 0:
        return edge_index.new_empty((2, 0))
    # No clone needed: `edge_index + offset` already allocates a new tensor,
    # so the input is never mutated.
    shifted = [edge_index + i * num_nodes_per_batch for i in range(B)]
    return torch.cat(shifted, dim=1)
class DoubleConvBlock(nn.Module):
def __init__(self, feats: tuple[int, int, int]):
super().__init__()
@ -41,7 +50,7 @@ class GCNBlock(nn.Module):
# Construct batched edge index
device = x.device
edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W)
edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W)
# Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling)
# batch = torch.arange(B, device=device).repeat_interleave(H * W)
@ -57,15 +66,39 @@ class GCNBlock(nn.Module):
# Reshape back to [B, C, H, W]
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
return x
class TransformerGCNBlock(nn.Module):
    """Graph block built on TransformerConv, operating on an H x W grid graph.

    Takes a feature map of shape ``[B, C, H, W]``, flattens it to one node per
    cell, runs two TransformerConv layers over the grid graph produced by
    ``torch_geometric.utils.grid``, and reshapes the result back to
    ``[B, out_ch, H, W]``.

    NOTE(review): conv1 uses ``hidden_ch // 8`` channels per head with 8
    concatenated heads, so ``hidden_ch`` must be divisible by 8 for the
    following ``LayerNorm(hidden_ch)`` to match — confirm at call sites.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        # 8 heads, concatenated back to hidden_ch total channels.
        self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True)
        self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1)
        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(out_ch)
        # Edge index of a single h x w grid graph; offset per batch in forward.
        self.single_edge_index, _ = grid(h, w)  # [2, E] for a single map

    def _batch_edge_index(self, B, edge_index, num_nodes_per_batch):
        # Kept for backward compatibility; the logic now lives in the
        # module-level batch_edge_index helper (this copy was dead code).
        return batch_edge_index(B, edge_index, num_nodes_per_batch)

    def forward(self, x):
        # x: [B, C, H, W]
        B, C, H, W = x.shape
        # Flatten the spatial grid: one node per cell -> [B * H * W, C].
        x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        # Offset the single-map edge index for each batch element.
        device = x.device
        edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W)
        # Two TransformerConv layers, each followed by LayerNorm + ELU.
        x = self.conv1(x, edge_index)
        x = F.elu(self.norm1(x))
        x = self.conv2(x, edge_index)
        x = F.elu(self.norm2(x))
        # Restore the spatial layout: [B, out_ch, H, W].
        x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
        return x
class ConvFusionModule(nn.Module):
def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):

37
ginka/generator/gcn.py Normal file
View File

@ -0,0 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TransformerConv
from torch_geometric.utils import grid
from ..common.cond import ConditionInjector
# Considering using a GCN as the generator's main path; kept as a placeholder for now
class GCNBlock(nn.Module):
    """Two stacked GCNConv layers, each followed by LayerNorm and ELU.

    ``feats`` supplies the (input, hidden, output) channel sizes.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        in_dim, mid_dim, out_dim = feats
        self.conv1 = GCNConv(in_dim, mid_dim)
        self.conv2 = GCNConv(mid_dim, out_dim)
        self.norm1 = nn.LayerNorm(mid_dim)
        self.norm2 = nn.LayerNorm(out_dim)

    def forward(self, x, edge_index):
        # Run each stage in order: conv -> LayerNorm -> ELU.
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            x = F.elu(norm(conv(x, edge_index)))
        return x
class GinkaGCNEncoder(nn.Module):
    """Encoder stub for a planned GCN-based generator path; no layers yet."""
    def __init__(self):
        super().__init__()
class GinkaGCNDecoder(nn.Module):
    """Decoder stub for a planned GCN-based generator path; no layers yet."""
    def __init__(self):
        super().__init__()
class GinkaGCNModel(nn.Module):
    """Model stub for a planned GCN-based generator; wiring not implemented."""
    def __init__(self):
        super().__init__()

View File

@ -249,7 +249,7 @@ def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
return penalty
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]):
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.05, 0.5]):
# weight: 判别器损失CE 损失,不可修改类型损失和非法图块损失,图块类型损失,入口存在性损失,多样性损失,密度损失
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from shared.attention import ChannelAttention
from ..common.common import GCNBlock
from ..common.common import GCNBlock, TransformerGCNBlock
from ..common.cond import ConditionInjector
class GinkaTransformerEncoder(nn.Module):
@ -65,7 +65,7 @@ class GinkaUNetInput(nn.Module):
def __init__(self, in_ch, out_ch, w, h):
super().__init__()
self.conv = ConvBlock(in_ch, in_ch)
self.gcn = GCNBlock(in_ch, in_ch*2, in_ch, w, h)
self.gcn = TransformerGCNBlock(in_ch, in_ch*2, in_ch, w, h)
self.fusion = ConvBlock(in_ch*2, out_ch)
self.inject = ConditionInjector(256, out_ch)
@ -95,7 +95,7 @@ class GinkaGCNFusedEncoder(nn.Module):
def __init__(self, in_ch, out_ch, w, h):
super().__init__()
self.conv = ConvBlock(in_ch, out_ch)
self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
self.pool = nn.MaxPool2d(2)
self.fusion = FusionModule(out_ch*2, out_ch)
self.inject = ConditionInjector(256, out_ch)
@ -140,7 +140,7 @@ class GinkaGCNFusedDecoder(nn.Module):
super().__init__()
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
self.conv = ConvBlock(in_ch, out_ch)
self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
self.fusion = FusionModule(out_ch*2, out_ch)
self.inject = ConditionInjector(256, out_ch)
@ -156,26 +156,27 @@ class GinkaGCNFusedDecoder(nn.Module):
class GinkaBottleneck(nn.Module):
def __init__(self, module_ch, w, h):
super().__init__()
self.transformer = GinkaTransformerEncoder(
in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
token_size=16, ff_dim=1024, num_layers=4
)
self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
# self.conv = ConvBlock(module_ch, module_ch)
# self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, w, h)
# self.fusion = FusionModule(module_ch*2, module_ch)
# self.transformer = GinkaTransformerEncoder(
# in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
# token_size=16, ff_dim=1024, num_layers=4
# )
# self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
# self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
self.conv = ConvBlock(module_ch, module_ch)
self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, w, h)
self.fusion = nn.Conv2d(module_ch*2, module_ch, 1)
self.inject = ConditionInjector(256, module_ch)
def forward(self, x, cond):
B = x.size(0)
x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
x1 = self.transformer(x1)
x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
# x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
# x1 = self.transformer(x1)
# x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
x1 = self.conv(x)
x2 = self.gcn(x)
x = torch.cat([x, x1, x2], dim=1)
x = torch.cat([x1, x2], dim=1)
x = self.fusion(x)
x = self.inject(x, cond)

View File

@ -46,7 +46,7 @@ from shared.image import matrix_to_image_cv
# 29. 楼梯入口
# 30. 箭头入口
BATCH_SIZE = 16
BATCH_SIZE = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
@ -350,14 +350,14 @@ def train():
else:
g_steps = 1
if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
g_steps += int(min(avg_loss_ginka * 5, 50))
# if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
# g_steps += int(min(avg_loss_ginka * 5, 50))
if avg_loss_minamo > 0:
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
else:
c_steps = 5
dataset.train_stage = train_stage
dataset_val.train_stage = train_stage
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio

View File

@ -1,6 +1,6 @@
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.utils import add_self_loops
from torch_geometric.utils import add_self_loops, grid
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
"""
@ -14,38 +14,40 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
N = H * W
# 1. 节点特征
node_features = map_probs.view(C, -1).T # [N, C]
node_features = map_probs.view(C, H * W).T # [N, C]
edge_index, _ = grid(H, W)
edge_index = edge_index.to(device)
# 2. 构建所有可能的边连接
node_indices = torch.arange(N, device=device).view(H, W)
# node_indices = torch.arange(N, device=device).view(H, W)
# 水平连接(右邻居)
right_src = node_indices[:, :-1].flatten()
right_dst = node_indices[:, 1:].flatten()
# # 水平连接(右邻居)
# right_src = node_indices[:, :-1].flatten()
# right_dst = node_indices[:, 1:].flatten()
# 垂直连接(下邻居)
down_src = node_indices[:-1, :].flatten()
down_dst = node_indices[1:, :].flatten()
# # 垂直连接(下邻居)
# down_src = node_indices[:-1, :].flatten()
# down_dst = node_indices[1:, :].flatten()
# 合并边列表(双向)
edge_src = torch.cat([right_src, down_src])
edge_dst = torch.cat([right_dst, down_dst])
edge_index = torch.cat([
torch.stack([edge_src, edge_dst], dim=0),
torch.stack([edge_dst, edge_src], dim=0) # 反向连接
], dim=1).to(device, dtype=torch.long)
# # 合并边列表(双向)
# edge_src = torch.cat([right_src, down_src])
# edge_dst = torch.cat([right_dst, down_dst])
# edge_index = torch.cat([
# torch.stack([edge_src, edge_dst], dim=0),
# torch.stack([edge_dst, edge_src], dim=0) # 反向连接
# ], dim=1).to(device, dtype=torch.long)
# 3. 计算边特征
src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C]
dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C]
edge_attr = (src_feat + dst_feat) / 2 # [E, C]
# # 3. 计算边特征
# src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C]
# dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C]
# edge_attr = (src_feat + dst_feat) / 2 # [E, C]
edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
# edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
return Data(
x=node_features,
edge_index=edge_index,
edge_attr=edge_attr,
# edge_attr=edge_attr,
num_nodes=N
)