mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 图卷积部分换用 TransformerConv
This commit is contained in:
parent
7b138c66d9
commit
21b693ec21
@ -1,9 +1,18 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv
|
||||
from torch_geometric.nn import GCNConv, TransformerConv
|
||||
from torch_geometric.utils import grid
|
||||
|
||||
def batch_edge_index(B, edge_index, num_nodes_per_batch):
    """Replicate a single-graph edge index for a batch of B disjoint graphs.

    Args:
        B: batch size (number of graph copies).
        edge_index: [2, E] integer tensor of edges for one graph.
        num_nodes_per_batch: node count of one graph; copy i is offset by
            i * num_nodes_per_batch so node ids of different copies never collide.

    Returns:
        [2, B * E] tensor — the B offset copies concatenated along dim=1,
        in batch order (identical to torch.cat of per-batch copies).
    """
    # Vectorized instead of a Python loop + torch.cat: build every per-batch
    # offset at once and broadcast it over the edge list. Also returns an
    # empty [2, 0] tensor for B == 0 instead of crashing in torch.cat.
    offsets = torch.arange(B, device=edge_index.device, dtype=edge_index.dtype)
    offsets = offsets * num_nodes_per_batch  # [B]
    # [1, 2, E] + [B, 1, 1] -> [B, 2, E]; reorder to [2, B, E] and flatten
    # the last two dims so ordering matches the original cat(dim=1) result.
    batched = edge_index.unsqueeze(0) + offsets.view(-1, 1, 1)
    return batched.permute(1, 0, 2).reshape(2, -1)
|
||||
|
||||
class DoubleConvBlock(nn.Module):
|
||||
def __init__(self, feats: tuple[int, int, int]):
|
||||
super().__init__()
|
||||
@ -41,7 +50,7 @@ class GCNBlock(nn.Module):
|
||||
|
||||
# Construct batched edge index
|
||||
device = x.device
|
||||
edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W)
|
||||
edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W)
|
||||
|
||||
# Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling)
|
||||
# batch = torch.arange(B, device=device).repeat_interleave(H * W)
|
||||
@ -57,15 +66,39 @@ class GCNBlock(nn.Module):
|
||||
# Reshape back to [B, C, H, W]
|
||||
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
||||
return x
|
||||
|
||||
class TransformerGCNBlock(nn.Module):
    """Two-layer TransformerConv block operating on an image-shaped tensor.

    The [B, C, H, W] input is flattened to H*W graph nodes per sample,
    connected as a 4-neighbour grid (precomputed once in __init__), run
    through two graph-attention layers, and reshaped back to [B, C', H, W].
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        # conv1 uses 8 heads with concat=True, so each head must emit
        # hidden_ch // 8 channels for the concatenation to equal hidden_ch
        # (which norm1 expects). Fail fast on an incompatible width instead
        # of producing a confusing shape error at forward time.
        if hidden_ch % 8 != 0:
            raise ValueError("hidden_ch must be divisible by 8 (8 concatenated heads)")
        self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True)
        self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1)
        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(out_ch)
        self.single_edge_index, _ = grid(h, w)  # [2, E] for a single map

    def _batch_edge_index(self, B, edge_index, num_nodes_per_batch):
        # Kept for backward compatibility; delegates to the module-level
        # helper so the batch-offset logic lives in exactly one place.
        return batch_edge_index(B, edge_index, num_nodes_per_batch)

    def forward(self, x):
        # x: [B, C, H, W]
        B, C, H, W = x.shape

        # Flatten the spatial grid into one node list: [B * H * W, C]
        x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)

        # Offset the single-map edge index per batch element so the B maps
        # form disjoint subgraphs.
        device = x.device
        edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W)

        # Two rounds of graph attention, each followed by LayerNorm + ELU
        x = self.conv1(x, edge_index)
        x = F.elu(self.norm1(x))
        x = self.conv2(x, edge_index)
        x = F.elu(self.norm2(x))

        # Restore the image layout: [B, out_ch, H, W]
        x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
        return x
|
||||
|
||||
class ConvFusionModule(nn.Module):
|
||||
def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
|
||||
|
||||
37
ginka/generator/gcn.py
Normal file
37
ginka/generator/gcn.py
Normal file
@ -0,0 +1,37 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv, TransformerConv
|
||||
from torch_geometric.utils import grid
|
||||
from ..common.cond import ConditionInjector
|
||||
|
||||
# 考虑使用 GCN 作为生成器主路径,暂时先留着
|
||||
|
||||
class GCNBlock(nn.Module):
    """Two-layer GCN stack: feats gives the (input, hidden, output) widths.

    Each GCNConv is followed by LayerNorm and an ELU activation.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        in_ch, hidden_ch, out_ch = feats
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, out_ch)

        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(out_ch)

    def forward(self, x, edge_index):
        # Apply both conv -> norm -> ELU stages in order.
        stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        for conv, norm in stages:
            x = F.elu(norm(conv(x, edge_index)))
        return x
|
||||
|
||||
class GinkaGCNEncoder(nn.Module):
    """Placeholder for a GCN-based generator encoder; no layers defined yet."""

    def __init__(self):
        super().__init__()
|
||||
|
||||
class GinkaGCNDecoder(nn.Module):
    """Placeholder for a GCN-based generator decoder; no layers defined yet."""

    def __init__(self):
        super().__init__()
|
||||
|
||||
class GinkaGCNModel(nn.Module):
    """Placeholder for the full GCN generator model; no layers defined yet."""

    def __init__(self):
        super().__init__()
|
||||
@ -249,7 +249,7 @@ def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
|
||||
return penalty
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]):
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.05, 0.5]):
|
||||
# weight: 判别器损失,CE 损失,不可修改类型损失和非法图块损失,图块类型损失,入口存在性损失,多样性损失,密度损失
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
|
||||
@ -2,7 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from shared.attention import ChannelAttention
|
||||
from ..common.common import GCNBlock
|
||||
from ..common.common import GCNBlock, TransformerGCNBlock
|
||||
from ..common.cond import ConditionInjector
|
||||
|
||||
class GinkaTransformerEncoder(nn.Module):
|
||||
@ -65,7 +65,7 @@ class GinkaUNetInput(nn.Module):
|
||||
def __init__(self, in_ch, out_ch, w, h):
|
||||
super().__init__()
|
||||
self.conv = ConvBlock(in_ch, in_ch)
|
||||
self.gcn = GCNBlock(in_ch, in_ch*2, in_ch, w, h)
|
||||
self.gcn = TransformerGCNBlock(in_ch, in_ch*2, in_ch, w, h)
|
||||
self.fusion = ConvBlock(in_ch*2, out_ch)
|
||||
self.inject = ConditionInjector(256, out_ch)
|
||||
|
||||
@ -95,7 +95,7 @@ class GinkaGCNFusedEncoder(nn.Module):
|
||||
def __init__(self, in_ch, out_ch, w, h):
|
||||
super().__init__()
|
||||
self.conv = ConvBlock(in_ch, out_ch)
|
||||
self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
|
||||
self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
|
||||
self.pool = nn.MaxPool2d(2)
|
||||
self.fusion = FusionModule(out_ch*2, out_ch)
|
||||
self.inject = ConditionInjector(256, out_ch)
|
||||
@ -140,7 +140,7 @@ class GinkaGCNFusedDecoder(nn.Module):
|
||||
super().__init__()
|
||||
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
|
||||
self.conv = ConvBlock(in_ch, out_ch)
|
||||
self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
|
||||
self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
|
||||
self.fusion = FusionModule(out_ch*2, out_ch)
|
||||
self.inject = ConditionInjector(256, out_ch)
|
||||
|
||||
@ -156,26 +156,27 @@ class GinkaGCNFusedDecoder(nn.Module):
|
||||
class GinkaBottleneck(nn.Module):
|
||||
def __init__(self, module_ch, w, h):
|
||||
super().__init__()
|
||||
self.transformer = GinkaTransformerEncoder(
|
||||
in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
|
||||
token_size=16, ff_dim=1024, num_layers=4
|
||||
)
|
||||
self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
|
||||
self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
|
||||
# self.conv = ConvBlock(module_ch, module_ch)
|
||||
# self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, w, h)
|
||||
# self.fusion = FusionModule(module_ch*2, module_ch)
|
||||
# self.transformer = GinkaTransformerEncoder(
|
||||
# in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
|
||||
# token_size=16, ff_dim=1024, num_layers=4
|
||||
# )
|
||||
# self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
|
||||
# self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
|
||||
self.conv = ConvBlock(module_ch, module_ch)
|
||||
self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, w, h)
|
||||
self.fusion = nn.Conv2d(module_ch*2, module_ch, 1)
|
||||
self.inject = ConditionInjector(256, module_ch)
|
||||
|
||||
def forward(self, x, cond):
|
||||
B = x.size(0)
|
||||
|
||||
x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
|
||||
x1 = self.transformer(x1)
|
||||
x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
|
||||
# x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
|
||||
# x1 = self.transformer(x1)
|
||||
# x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
|
||||
x1 = self.conv(x)
|
||||
x2 = self.gcn(x)
|
||||
|
||||
x = torch.cat([x, x1, x2], dim=1)
|
||||
x = torch.cat([x1, x2], dim=1)
|
||||
x = self.fusion(x)
|
||||
x = self.inject(x, cond)
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ from shared.image import matrix_to_image_cv
|
||||
# 29. 楼梯入口
|
||||
# 30. 箭头入口
|
||||
|
||||
BATCH_SIZE = 16
|
||||
BATCH_SIZE = 8
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
@ -350,14 +350,14 @@ def train():
|
||||
else:
|
||||
g_steps = 1
|
||||
|
||||
if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
|
||||
g_steps += int(min(avg_loss_ginka * 5, 50))
|
||||
# if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
|
||||
# g_steps += int(min(avg_loss_ginka * 5, 50))
|
||||
|
||||
if avg_loss_minamo > 0:
|
||||
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
||||
else:
|
||||
c_steps = 5
|
||||
|
||||
|
||||
dataset.train_stage = train_stage
|
||||
dataset_val.train_stage = train_stage
|
||||
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from torch_geometric.data import Data, Batch
|
||||
from torch_geometric.utils import add_self_loops
|
||||
from torch_geometric.utils import add_self_loops, grid
|
||||
|
||||
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
|
||||
"""
|
||||
@ -14,38 +14,40 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
|
||||
N = H * W
|
||||
|
||||
# 1. 节点特征
|
||||
node_features = map_probs.view(C, -1).T # [N, C]
|
||||
node_features = map_probs.view(C, H * W).T # [N, C]
|
||||
edge_index, _ = grid(H, W)
|
||||
edge_index = edge_index.to(device)
|
||||
|
||||
# 2. 构建所有可能的边连接
|
||||
node_indices = torch.arange(N, device=device).view(H, W)
|
||||
# node_indices = torch.arange(N, device=device).view(H, W)
|
||||
|
||||
# 水平连接(右邻居)
|
||||
right_src = node_indices[:, :-1].flatten()
|
||||
right_dst = node_indices[:, 1:].flatten()
|
||||
# # 水平连接(右邻居)
|
||||
# right_src = node_indices[:, :-1].flatten()
|
||||
# right_dst = node_indices[:, 1:].flatten()
|
||||
|
||||
# 垂直连接(下邻居)
|
||||
down_src = node_indices[:-1, :].flatten()
|
||||
down_dst = node_indices[1:, :].flatten()
|
||||
# # 垂直连接(下邻居)
|
||||
# down_src = node_indices[:-1, :].flatten()
|
||||
# down_dst = node_indices[1:, :].flatten()
|
||||
|
||||
# 合并边列表(双向)
|
||||
edge_src = torch.cat([right_src, down_src])
|
||||
edge_dst = torch.cat([right_dst, down_dst])
|
||||
edge_index = torch.cat([
|
||||
torch.stack([edge_src, edge_dst], dim=0),
|
||||
torch.stack([edge_dst, edge_src], dim=0) # 反向连接
|
||||
], dim=1).to(device, dtype=torch.long)
|
||||
# # 合并边列表(双向)
|
||||
# edge_src = torch.cat([right_src, down_src])
|
||||
# edge_dst = torch.cat([right_dst, down_dst])
|
||||
# edge_index = torch.cat([
|
||||
# torch.stack([edge_src, edge_dst], dim=0),
|
||||
# torch.stack([edge_dst, edge_src], dim=0) # 反向连接
|
||||
# ], dim=1).to(device, dtype=torch.long)
|
||||
|
||||
# 3. 计算边特征
|
||||
src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C]
|
||||
dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C]
|
||||
edge_attr = (src_feat + dst_feat) / 2 # [E, C]
|
||||
# # 3. 计算边特征
|
||||
# src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C]
|
||||
# dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C]
|
||||
# edge_attr = (src_feat + dst_feat) / 2 # [E, C]
|
||||
|
||||
edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
|
||||
# edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
|
||||
|
||||
return Data(
|
||||
x=node_features,
|
||||
edge_index=edge_index,
|
||||
edge_attr=edge_attr,
|
||||
# edge_attr=edge_attr,
|
||||
num_nodes=N
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user