diff --git a/ginka/common/common.py b/ginka/common/common.py
index 681ec46..59e70b0 100644
--- a/ginka/common/common.py
+++ b/ginka/common/common.py
@@ -1,9 +1,18 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch_geometric.nn import GCNConv
+from torch_geometric.nn import GCNConv, TransformerConv
 from torch_geometric.utils import grid
 
+def batch_edge_index(B, edge_index, num_nodes_per_batch):
+    # Offset the single-map edge_index once per sample in the batch
+    edge_index = edge_index.clone()  # [2, E]
+    batch_edge_index = []
+    for i in range(B):
+        offset = i * num_nodes_per_batch
+        batch_edge_index.append(edge_index + offset)
+    return torch.cat(batch_edge_index, dim=1)
+
 class DoubleConvBlock(nn.Module):
     def __init__(self, feats: tuple[int, int, int]):
         super().__init__()
@@ -41,7 +50,7 @@ class GCNBlock(nn.Module):
 
         # Construct batched edge index
         device = x.device
-        edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W)
+        edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W)
 
         # Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling)
         # batch = torch.arange(B, device=device).repeat_interleave(H * W)
@@ -57,15 +66,39 @@ class GCNBlock(nn.Module):
         # Reshape back to [B, C, H, W]
         x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
         return x
+
+class TransformerGCNBlock(nn.Module):
+    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
+        super().__init__()
+        self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True)
+        self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1)
+        self.norm1 = nn.LayerNorm(hidden_ch)
+        self.norm2 = nn.LayerNorm(out_ch)
+        self.single_edge_index, _ = grid(h, w)  # [2, E] for a single map
 
-    def _batch_edge_index(self, B, edge_index, num_nodes_per_batch):
-        # Offset the single-map edge_index once per sample in the batch
-        edge_index = edge_index.clone()  # [2, E]
-        batch_edge_index = []
-        for i in range(B):
-            offset = i * num_nodes_per_batch
-            batch_edge_index.append(edge_index + offset)
-        return torch.cat(batch_edge_index, dim=1)
+    def forward(self, x):
+        # x: [B, C, H, W]
+        B, C, H, W = x.shape
+
+        # Reshape to [B * H * W, C]
+        x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
+
+        # Construct batched edge index
+        device = x.device
+        edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W)
+
+        # Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling)
+        # batch = torch.arange(B, device=device).repeat_interleave(H * W)
+
+        # Graph conv forward (TransformerConv)
+        x = self.conv1(x, edge_index)
+        x = F.elu(self.norm1(x))
+        x = self.conv2(x, edge_index)
+        x = F.elu(self.norm2(x))
+
+        # Reshape back to [B, C, H, W]
+        x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
+        return x
 
 class ConvFusionModule(nn.Module):
     def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
diff --git a/ginka/generator/gcn.py b/ginka/generator/gcn.py
new file mode 100644
index 0000000..cb4d24a
--- /dev/null
+++ b/ginka/generator/gcn.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.nn import GCNConv, TransformerConv
+from torch_geometric.utils import grid
+from ..common.cond import ConditionInjector
+
+# Considering using a GCN as the generator's main path; kept as a placeholder for now
+
+class GCNBlock(nn.Module):
+    def __init__(self, feats: tuple[int, int, int]):
+        super().__init__()
+        self.conv1 = GCNConv(feats[0], feats[1])
+        self.conv2 = GCNConv(feats[1], feats[2])
+
+        self.norm1 = nn.LayerNorm(feats[1])
+        self.norm2 = nn.LayerNorm(feats[2])
+
+    def forward(self, x, edge_index):
+        x = self.conv1(x, edge_index)
+        x = F.elu(self.norm1(x))
+
+        x = self.conv2(x, edge_index)
+        x = F.elu(self.norm2(x))
+        return x
+
+class GinkaGCNEncoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+class GinkaGCNDecoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+class GinkaGCNModel(nn.Module):
+    def __init__(self):
+        super().__init__()
\ No newline at end of file
diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py
index 91046bd..b05d917 100644
--- a/ginka/generator/loss.py
+++ b/ginka/generator/loss.py
@@ -249,7 +249,7 @@ def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
     return penalty
 
 class WGANGinkaLoss:
-    def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]):
+    def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.05, 0.5]):
         # weight: discriminator loss, CE loss, immutable-type and illegal-tile loss, tile-type loss, entrance-presence loss, diversity loss, density loss
         self.lambda_gp = lambda_gp  # gradient penalty coefficient
         self.weight = weight
diff --git a/ginka/generator/unet.py b/ginka/generator/unet.py
index a769df5..839b2e9 100644
--- a/ginka/generator/unet.py
+++ b/ginka/generator/unet.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from shared.attention import ChannelAttention
-from ..common.common import GCNBlock
+from ..common.common import GCNBlock, TransformerGCNBlock
 from ..common.cond import ConditionInjector
 
 class GinkaTransformerEncoder(nn.Module):
@@ -65,7 +65,7 @@ class GinkaUNetInput(nn.Module):
     def __init__(self, in_ch, out_ch, w, h):
         super().__init__()
         self.conv = ConvBlock(in_ch, in_ch)
-        self.gcn = GCNBlock(in_ch, in_ch*2, in_ch, w, h)
+        self.gcn = TransformerGCNBlock(in_ch, in_ch*2, in_ch, w, h)
         self.fusion = ConvBlock(in_ch*2, out_ch)
         self.inject = ConditionInjector(256, out_ch)
 
@@ -95,7 +95,7 @@ class GinkaGCNFusedEncoder(nn.Module):
     def __init__(self, in_ch, out_ch, w, h):
         super().__init__()
         self.conv = ConvBlock(in_ch, out_ch)
-        self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
+        self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
         self.pool = nn.MaxPool2d(2)
         self.fusion = FusionModule(out_ch*2, out_ch)
         self.inject = ConditionInjector(256, out_ch)
@@ -140,7 +140,7 @@ class GinkaGCNFusedDecoder(nn.Module):
         super().__init__()
         self.upsample = GinkaUpSample(in_ch, in_ch // 2)
         self.conv = ConvBlock(in_ch, out_ch)
-        self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
+        self.gcn = TransformerGCNBlock(out_ch, out_ch*2, out_ch, w, h)
         self.fusion = FusionModule(out_ch*2, out_ch)
         self.inject = ConditionInjector(256, out_ch)
 
@@ -156,26 +156,27 @@ class GinkaGCNFusedDecoder(nn.Module):
 class GinkaBottleneck(nn.Module):
     def __init__(self, module_ch, w, h):
         super().__init__()
-        self.transformer = GinkaTransformerEncoder(
-            in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
-            token_size=16, ff_dim=1024, num_layers=4
-        )
-        self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
-        self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
-        # self.conv = ConvBlock(module_ch, module_ch)
-        # self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, w, h)
-        # self.fusion = FusionModule(module_ch*2, module_ch)
+        # self.transformer = GinkaTransformerEncoder(
+        #     in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
+        #     token_size=16, ff_dim=1024, num_layers=4
+        # )
+        # self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
+        # self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
+        self.conv = ConvBlock(module_ch, module_ch)
+        self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, w, h)
+        self.fusion = nn.Conv2d(module_ch*2, module_ch, 1)
         self.inject = ConditionInjector(256, module_ch)
 
     def forward(self, x, cond):
         B = x.size(0)
-        x1 = x.view(B, 512, 16).permute(0, 2, 1)  # [B, 16, in_ch]
-        x1 = self.transformer(x1)
-        x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4)  # [B, out_ch, 4, 4]
+        # x1 = x.view(B, 512, 16).permute(0, 2, 1)  # [B, 16, in_ch]
+        # x1 = self.transformer(x1)
+        # x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4)  # [B, out_ch, 4, 4]
+        x1 = self.conv(x)
         x2 = self.gcn(x)
-        x = torch.cat([x, x1, x2], dim=1)
+        x = torch.cat([x1, x2], dim=1)
         x = self.fusion(x)
         x = self.inject(x, cond)
diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py
index dd4b54b..9f63e6f 100644
--- a/ginka/train_wgan.py
+++ b/ginka/train_wgan.py
@@ -46,7 +46,7 @@ from shared.image import matrix_to_image_cv
 # 29. Stair entrance
 # 30. Arrow entrance
 
-BATCH_SIZE = 16
+BATCH_SIZE = 8
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 os.makedirs("result", exist_ok=True)
 
@@ -350,14 +350,14 @@ def train():
         else:
             g_steps = 1
 
-        if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
-            g_steps += int(min(avg_loss_ginka * 5, 50))
+        # if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
+        #     g_steps += int(min(avg_loss_ginka * 5, 50))
 
         if avg_loss_minamo > 0:
             c_steps = int(min(5 + avg_loss_minamo * 5, 15))
         else:
             c_steps = 5
-        
+            
         dataset.train_stage = train_stage
         dataset_val.train_stage = train_stage
         dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
diff --git a/shared/graph.py b/shared/graph.py
index c109242..1eb0aa6 100644
--- a/shared/graph.py
+++ b/shared/graph.py
@@ -1,6 +1,6 @@
 import torch
 from torch_geometric.data import Data, Batch
-from torch_geometric.utils import add_self_loops
+from torch_geometric.utils import add_self_loops, grid
 
 def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
     """
@@ -14,38 +14,40 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
     N = H * W
 
     # 1. Node features
-    node_features = map_probs.view(C, -1).T  # [N, C]
+    node_features = map_probs.view(C, H * W).T  # [N, C]
+    edge_index, _ = grid(H, W)
+    edge_index = edge_index.to(device)
 
     # 2. Build all candidate edge connections
-    node_indices = torch.arange(N, device=device).view(H, W)
+    # node_indices = torch.arange(N, device=device).view(H, W)
 
-    # Horizontal connections (right neighbor)
-    right_src = node_indices[:, :-1].flatten()
-    right_dst = node_indices[:, 1:].flatten()
+    # # Horizontal connections (right neighbor)
+    # right_src = node_indices[:, :-1].flatten()
+    # right_dst = node_indices[:, 1:].flatten()
 
-    # Vertical connections (down neighbor)
-    down_src = node_indices[:-1, :].flatten()
-    down_dst = node_indices[1:, :].flatten()
+    # # Vertical connections (down neighbor)
+    # down_src = node_indices[:-1, :].flatten()
+    # down_dst = node_indices[1:, :].flatten()
 
-    # Merge edge lists (bidirectional)
-    edge_src = torch.cat([right_src, down_src])
-    edge_dst = torch.cat([right_dst, down_dst])
-    edge_index = torch.cat([
-        torch.stack([edge_src, edge_dst], dim=0),
-        torch.stack([edge_dst, edge_src], dim=0)  # reverse edges
-    ], dim=1).to(device, dtype=torch.long)
+    # # Merge edge lists (bidirectional)
+    # edge_src = torch.cat([right_src, down_src])
+    # edge_dst = torch.cat([right_dst, down_dst])
+    # edge_index = torch.cat([
+    #     torch.stack([edge_src, edge_dst], dim=0),
+    #     torch.stack([edge_dst, edge_src], dim=0)  # reverse edges
+    # ], dim=1).to(device, dtype=torch.long)
 
-    # 3. Compute edge features
-    src_feat = map_probs[:, edge_src // W, edge_src % W].T  # [E, C]
-    dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T  # [E, C]
-    edge_attr = (src_feat + dst_feat) / 2  # [E, C]
+    # # 3. Compute edge features
+    # src_feat = map_probs[:, edge_src // W, edge_src % W].T  # [E, C]
+    # dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T  # [E, C]
+    # edge_attr = (src_feat + dst_feat) / 2  # [E, C]
 
-    edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
+    # edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
 
     return Data(
         x=node_features,
         edge_index=edge_index,
-        edge_attr=edge_attr,
+        # edge_attr=edge_attr,
         num_nodes=N
     )
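As a quick sanity check of the new graph path, the snippet below pushes a random tensor through the TransformerGCNBlock and batch_edge_index helper introduced in ginka/common/common.py. It is a minimal sketch only: it assumes the repository root is on PYTHONPATH and that torch and torch_geometric are installed, and the batch size, channel count, and 13x13 map size are illustrative values, not settings taken from the training config.

import torch
from ginka.common.common import TransformerGCNBlock, batch_edge_index

# Illustrative sizes only; the real channel counts and map size come from the model config.
B, C, H, W = 2, 32, 13, 13

# hidden_ch must be divisible by the 8 attention heads used inside the block.
block = TransformerGCNBlock(in_ch=C, hidden_ch=2 * C, out_ch=C, w=W, h=H)
x = torch.randn(B, C, H, W)

# The block flattens each map into H*W nodes and reuses one grid edge_index, offset per sample.
edge_index = batch_edge_index(B, block.single_edge_index, H * W)
assert edge_index.shape[0] == 2       # [2, B * E] batched connectivity

out = block(x)
assert out.shape == (B, C, H, W)      # the spatial layout is preserved
print(out.shape)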