From 22a2db464f64a7a34e59c411549922415d979287 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 11 Mar 2026 16:27:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20maskGIT=20=E5=8A=A0=E5=85=A5=E7=83=AD?= =?UTF-8?q?=E5=8A=9B=E5=9B=BE=E6=9D=A1=E4=BB=B6=E9=99=90=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/dataset.py | 24 +++++++- ginka/maskGIT/cond.py | 58 +++++++++++++++++++ ginka/{transformer => maskGIT}/mask.py | 0 .../maskGIT.py => maskGIT/model.py} | 29 ++++++---- ginka/train_maskGIT.py | 12 ++-- 5 files changed, 104 insertions(+), 19 deletions(-) create mode 100644 ginka/maskGIT/cond.py rename ginka/{transformer => maskGIT}/mask.py (100%) rename ginka/{transformer/maskGIT.py => maskGIT/model.py} (67%) diff --git a/ginka/dataset.py b/ginka/dataset.py index 6d7d463..b4a6873 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -237,4 +237,26 @@ class GinkaRNNDataset(Dataset): "tag_cond": tag_cond, "val_cond": val_cond, "target_map": target - } \ No newline at end of file + } + +class GinkaMaskGITDataset(Dataset): + def __init__(self, data_path: str, device): + self.data = load_data(data_path) # 自定义数据加载函数 + self.device = device + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + + target = torch.LongTensor(item['map']) # [H, W] + cond = torch.FloatTensor(item['val']) # [cond_dim] + heatmap = torch.FloatTensor(item['heatmap']) # [heatmap_channel, H, W] + + return { + "cond": cond, + "target_map": target, + "heatmap": heatmap + } + \ No newline at end of file diff --git a/ginka/maskGIT/cond.py b/ginka/maskGIT/cond.py new file mode 100644 index 0000000..292853f --- /dev/null +++ b/ginka/maskGIT/cond.py @@ -0,0 +1,58 @@ +import time +import torch +import torch.nn as nn +from ..utils import print_memory + +class GinkaMaskGITCond(nn.Module): + def __init__(self, cond_dim=16, heatmap_channel=4, output_dim=256): + super().__init__() + self.cond_fc = nn.Sequential( + nn.Linear(cond_dim, output_dim // 2), + nn.LayerNorm(output_dim // 2), + nn.ReLU(), + + nn.Linear(output_dim // 2, output_dim) + ) + + self.heatmap_conv = nn.Sequential( + nn.Conv2d(heatmap_channel, output_dim // 4, kernel_size=3, padding=1, padding_mode='replicate'), + nn.BatchNorm2d(output_dim // 4), + nn.ReLU(), + + nn.Conv2d(output_dim // 4, output_dim // 2, kernel_size=3, padding=1, padding_mode='replicate'), + nn.BatchNorm2d(output_dim // 2), + nn.ReLU(), + + nn.Conv2d(output_dim // 2, output_dim, kernel_size=3, padding=1, padding_mode='replicate') + ) + + def forward(self, cond, heatmap): + # cond: [B, cond_dim] + # heatmap: [B, C, H, W] + cond = self.cond_fc(cond) + heatmap = self.heatmap_conv(heatmap) + return cond, heatmap + +if __name__ == "__main__": + device = torch.device("cpu") + + cond = torch.rand(1, 16).to(device) + heatmap = torch.rand(1, 4, 13, 13).to(device) + + # 初始化模型 + model = GinkaMaskGITCond().to(device) + + print_memory("初始化后") + + # 前向传播 + start = time.perf_counter() + cond, heatmap = model(cond, heatmap) + end = time.perf_counter() + + print_memory("前向传播后") + + print(f"推理耗时: {end - start}") + print(f"输出形状: cond={cond.shape}, heatmap={heatmap.shape}") + print(f"Cond FC parameters: {sum(p.numel() for p in model.cond_fc.parameters())}") + print(f"Heatmap Conv parameters: {sum(p.numel() for p in model.heatmap_conv.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/transformer/mask.py b/ginka/maskGIT/mask.py similarity index 100% rename from ginka/transformer/mask.py rename to ginka/maskGIT/mask.py diff --git a/ginka/transformer/maskGIT.py b/ginka/maskGIT/model.py similarity index 67% rename from ginka/transformer/maskGIT.py rename to ginka/maskGIT/model.py index 59f4de7..253e26e 100644 --- a/ginka/transformer/maskGIT.py +++ b/ginka/maskGIT/model.py @@ -2,19 +2,19 @@ import time import torch import torch.nn as nn from ..utils import print_memory +from .cond import GinkaMaskGITCond class GinkaMaskGIT(nn.Module): def __init__( - self, num_classes=16, cond_dim=16, d_model=256, dim_ff=512, nhead=8, num_layers=4, map_size=13*13 + self, num_classes=16, cond_dim=16, heatmap_channel=4, d_model=256, + dim_ff=512, nhead=8, num_layers=4, map_size=13*13 ): super().__init__() self.tile_embedding = nn.Embedding(num_classes, d_model) - self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model)) + self.pos_embedding = nn.Parameter(torch.randn(1, map_size + 1, d_model)) - self.cond_projection = nn.Sequential( - nn.Linear(cond_dim, d_model) - ) + self.cond_encoder = GinkaMaskGITCond(cond_dim=cond_dim, heatmap_channel=heatmap_channel, output_dim=d_model) self.encoder = nn.TransformerEncoder( nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True), @@ -29,14 +29,20 @@ class GinkaMaskGIT(nn.Module): nn.Linear(d_model, num_classes) ) - def forward(self, map: torch.Tensor, cond: torch.Tensor): + def forward(self, map: torch.Tensor, cond: torch.Tensor, heatmap: torch.Tensor): # map: [B, H * W] # cond: [B, cond_dim] + # heatmap: [B, C, H, W] # output: [B, H * W, num_classes] + cond, heatmap = self.cond_encoder(cond, heatmap) + # cond: [B, d_model] + # heatmap: [B, d_model, H, W] - x = self.tile_embedding(map) + self.pos_embedding - c = self.cond_projection(cond).unsqueeze(1) - x = torch.cat([c, x], dim=1) + B, C, H, W = heatmap.shape + + heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1) + x = self.tile_embedding(map) + heatmap + x = torch.cat([cond.unsqueeze(1), x], dim=1) + self.pos_embedding m = self.encoder(x) out = self.decoder(x, m) @@ -50,6 +56,7 @@ if __name__ == "__main__": map = torch.randint(0, 16, [1, 169]).to(device) cond = torch.rand(1, 16).to(device) + heatmap = torch.rand(1, 4, 13, 13).to(device) # 初始化模型 model = GinkaMaskGIT().to(device) @@ -58,7 +65,7 @@ if __name__ == "__main__": # 前向传播 start = time.perf_counter() - output = model(map, cond) + output = model(map, cond, heatmap) end = time.perf_counter() print_memory("前向传播后") @@ -66,7 +73,7 @@ if __name__ == "__main__": print(f"推理耗时: {end - start}") print(f"输出形状: output={output.shape}") print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}") - print(f"Projection parameters: {sum(p.numel() for p in model.cond_projection.parameters())}") + print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond_encoder.parameters())}") print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}") print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}") print(f"Output parameters: {sum(p.numel() for p in model.output_fc.parameters())}") diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index fc5f1ba..c9f6a13 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -11,12 +11,10 @@ import cv2 import numpy as np from torch_geometric.loader import DataLoader from tqdm import tqdm -from .transformer.maskGIT import GinkaMaskGIT -from .vae_rnn.loss import VAELoss -from .vae_rnn.scheduler import VAEScheduler -from .dataset import GinkaRNNDataset +from .maskGIT.model import GinkaMaskGIT +from .dataset import GinkaMaskGITDataset from shared.image import matrix_to_image_cv -from .transformer.mask import MapMask +from .maskGIT.mask import MapMask # 手工标注标签定义(暂时不用): # 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, @@ -83,8 +81,8 @@ def train(): model = GinkaMaskGIT(num_classes=NUM_CLASSES).to(device) masker = MapMask([0.5, 0.5]) - dataset = GinkaRNNDataset(args.train, device) - dataset_val = GinkaRNNDataset(args.validate, device) + dataset = GinkaMaskGITDataset(args.train, device) + dataset_val = GinkaMaskGITDataset(args.validate, device) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)