From 020c2cf16874cfb2c2fe3a7d1a97f065e7063c13 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Fri, 13 Mar 2026 20:14:20 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=9C=A8=E8=AE=AD=E7=BB=83=E6=97=B6?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=AB=98=E6=96=AF=E6=A8=A1=E7=B3=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/settings.json | 4 ++++ data/src/auto.ts | 4 ++-- data/src/auto/heatmap.ts | 2 ++ ginka/dataset.py | 36 +++++++++++++++++++++++++++++++----- ginka/train_maskGIT.py | 19 +++++++++---------- requirements.txt | 3 ++- 6 files changed, 50 insertions(+), 18 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..b8d2fef --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "python-envs.defaultEnvManager": "ms-python.python:system", + "python-envs.defaultPackageManager": "ms-python.python:pip" +} diff --git a/data/src/auto.ts b/data/src/auto.ts index 7db67a1..48155af 100644 --- a/data/src/auto.ts +++ b/data/src/auto.ts @@ -268,8 +268,8 @@ const labelConfig: IAutoLabelConfig = { maxFishCount: 2, minEntryCount: 1, maxEntryCount: 4, - guassainRadius: 2, - heatmapKernel: 3, + guassainRadius: 0, + heatmapKernel: 0, ignoreIssues: true, customTowerFilter: info => { // if (info.name !== 'Apeiria') { diff --git a/data/src/auto/heatmap.ts b/data/src/auto/heatmap.ts index 40876fe..0e049a7 100644 --- a/data/src/auto/heatmap.ts +++ b/data/src/auto/heatmap.ts @@ -8,6 +8,7 @@ export function generateHeatmap( tokens: Set, kernel: number = 5 ): number[][] { + if (kernel === 0) return map.map(v => v.slice()); if (kernel % 2 !== 1) { throw new Error(`Kernal size must be odd.`); } @@ -44,6 +45,7 @@ export function generateHeatmap( * @param sigma 标准差 */ export function gaussainHeatmap(map: number[][], sigma: number = 1) { + if (sigma === 0) return map.map(v => v.slice()); const radius = sigma * 3; const width = map[0].length; const height = map.length; diff --git a/ginka/dataset.py b/ginka/dataset.py index 9fcf55d..95b7f95 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -1,5 +1,8 @@ import json +import random import torch +import cv2 +import numpy as np from torch.utils.data import Dataset def load_data(path: str): @@ -11,11 +14,13 @@ def load_data(path: str): data_list.append(value) return data_list - + class GinkaMaskGITDataset(Dataset): - def __init__(self, data_path: str, device): - self.data = load_data(data_path) # 自定义数据加载函数 - self.device = device + def __init__(self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6): + self.data = load_data(data_path) + self.sigma_rand = sigma_rand + self.blur_min = blur_min + self.blur_max = blur_max def __len__(self): return len(self.data) @@ -25,7 +30,28 @@ class GinkaMaskGITDataset(Dataset): target = torch.LongTensor(item['map']) # [H, W] cond = torch.FloatTensor(item['val']) # [cond_dim] - heatmap = torch.FloatTensor(item['heatmap']) # [heatmap_channel, H, W] + heatmap = np.array(item['heatmap']) + + if random.random() < 0.5: + size = random.randint(self.blur_min, self.blur_max) + if size % 2 == 0: + size = size + 1 if random.random() < 0.5 else size - 1 + heatmap = cv2.GaussianBlur(heatmap, (size, size), 0) + else: + sizeX = random.randint(self.blur_min, self.blur_max) + sizeY = random.randint(self.blur_min, self.blur_max) + if sizeX % 2 == 0: + sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1 + if sizeY % 2 == 0: + sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1 + heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0) + + heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W] + + if random.random() < 0.5: + sigma = random.random() * self.sigma_rand + rand = torch.randn_like(heatmap) * sigma + heatmap = heatmap + rand return { "cond": cond, diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 9faa29d..f1fbc5d 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -44,6 +44,8 @@ GENERATE_STEP = 8 MAP_SIZE = 13 * 13 HEATMAP_CHANNEL = 9 LABEL_SMOOTHING = 0 +BLUR_MIN_SIZE = 3 +BLUR_MAX_SIZE = 6 RAND_RATIO = 0.1 MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 @@ -78,13 +80,12 @@ def train(): model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL).to(device) masker = MapMask([0.5, 0.5]) - dataset = GinkaMaskGITDataset(args.train, device) - dataset_val = GinkaMaskGITDataset(args.validate, device) + dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE) + dataset_val = GinkaMaskGITDataset(args.validate, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2) - # 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) # 用于生成图片 @@ -104,8 +105,8 @@ def train(): optimizer.load_state_dict(data_ginka["optim_state"]) print("Train from loaded state.") - - for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm): + + for epoch in tqdm(range(args.epochs), desc="MaskGIT Training", disable=disable_tqdm): loss_total = torch.Tensor([0]).to(device) for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): @@ -115,9 +116,6 @@ def train(): B, H, W = target_map.shape target_map = target_map.view(B, H * W) - rand = torch.randn_like(heatmap).to(device) * RAND_RATIO - if random.random() > 0.5: - heatmap = heatmap + rand mask = np.zeros((B, H * W)) for i in range(B): @@ -163,7 +161,7 @@ def train(): gap = 5 color = (255, 255, 255) # 白色 vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 - for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): + for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): # 1. 常规生成 target_map = batch["target_map"].to(device) cond = batch["cond"].to(device) @@ -226,7 +224,8 @@ def train(): break generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict) - cv2.imwrite(f"result/transformer_img/g-{idx}.png", generated_img) + img = np.block([[real_img], [vline], [generated_img]]) + cv2.imwrite(f"result/transformer_img/g-{idx}.png", img) avg_loss_val = val_loss_total.item() / len(dataloader_val) tqdm.write( diff --git a/requirements.txt b/requirements.txt index c451531..6134b46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ torch-geometric transformers scipy numpy -cv2 \ No newline at end of file +cv2 +perlin-noise \ No newline at end of file