feat: 在训练时处理高斯模糊

This commit is contained in:
unanmed 2026-03-13 20:14:20 +08:00
parent d0f86018f1
commit 020c2cf168
6 changed files with 50 additions and 18 deletions

4
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,4 @@
{
"python-envs.defaultEnvManager": "ms-python.python:system",
"python-envs.defaultPackageManager": "ms-python.python:pip"
}

View File

@ -268,8 +268,8 @@ const labelConfig: IAutoLabelConfig = {
maxFishCount: 2,
minEntryCount: 1,
maxEntryCount: 4,
guassainRadius: 2,
heatmapKernel: 3,
guassainRadius: 0,
heatmapKernel: 0,
ignoreIssues: true,
customTowerFilter: info => {
// if (info.name !== 'Apeiria') {

View File

@ -8,6 +8,7 @@ export function generateHeatmap(
tokens: Set<number>,
kernel: number = 5
): number[][] {
if (kernel === 0) return map.map(v => v.slice());
if (kernel % 2 !== 1) {
throw new Error(`Kernal size must be odd.`);
}
@ -44,6 +45,7 @@ export function generateHeatmap(
* @param sigma
*/
export function gaussainHeatmap(map: number[][], sigma: number = 1) {
if (sigma === 0) return map.map(v => v.slice());
const radius = sigma * 3;
const width = map[0].length;
const height = map.length;

View File

@ -1,5 +1,8 @@
import json
import random
import torch
import cv2
import numpy as np
from torch.utils.data import Dataset
def load_data(path: str):
@ -13,9 +16,11 @@ def load_data(path: str):
return data_list
class GinkaMaskGITDataset(Dataset):
def __init__(self, data_path: str, device):
self.data = load_data(data_path) # 自定义数据加载函数
self.device = device
def __init__(self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6):
self.data = load_data(data_path)
self.sigma_rand = sigma_rand
self.blur_min = blur_min
self.blur_max = blur_max
def __len__(self):
return len(self.data)
@ -25,7 +30,28 @@ class GinkaMaskGITDataset(Dataset):
target = torch.LongTensor(item['map']) # [H, W]
cond = torch.FloatTensor(item['val']) # [cond_dim]
heatmap = torch.FloatTensor(item['heatmap']) # [heatmap_channel, H, W]
heatmap = np.array(item['heatmap'])
if random.random() < 0.5:
size = random.randint(self.blur_min, self.blur_max)
if size % 2 == 0:
size = size + 1 if random.random() < 0.5 else size - 1
heatmap = cv2.GaussianBlur(heatmap, (size, size), 0)
else:
sizeX = random.randint(self.blur_min, self.blur_max)
sizeY = random.randint(self.blur_min, self.blur_max)
if sizeX % 2 == 0:
sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1
if sizeY % 2 == 0:
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0)
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
if random.random() < 0.5:
sigma = random.random() * self.sigma_rand
rand = torch.randn_like(heatmap) * sigma
heatmap = heatmap + rand
return {
"cond": cond,

View File

@ -44,6 +44,8 @@ GENERATE_STEP = 8
MAP_SIZE = 13 * 13
HEATMAP_CHANNEL = 9
LABEL_SMOOTHING = 0
BLUR_MIN_SIZE = 3
BLUR_MAX_SIZE = 6
RAND_RATIO = 0.1
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
@ -78,13 +80,12 @@ def train():
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL).to(device)
masker = MapMask([0.5, 0.5])
dataset = GinkaMaskGITDataset(args.train, device)
dataset_val = GinkaMaskGITDataset(args.validate, device)
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
dataset_val = GinkaMaskGITDataset(args.validate, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
# 用于生成图片
@ -105,7 +106,7 @@ def train():
print("Train from loaded state.")
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
for epoch in tqdm(range(args.epochs), desc="MaskGIT Training", disable=disable_tqdm):
loss_total = torch.Tensor([0]).to(device)
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
@ -115,9 +116,6 @@ def train():
B, H, W = target_map.shape
target_map = target_map.view(B, H * W)
rand = torch.randn_like(heatmap).to(device) * RAND_RATIO
if random.random() > 0.5:
heatmap = heatmap + rand
mask = np.zeros((B, H * W))
for i in range(B):
@ -163,7 +161,7 @@ def train():
gap = 5
color = (255, 255, 255) # 白色
vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
# 1. 常规生成
target_map = batch["target_map"].to(device)
cond = batch["cond"].to(device)
@ -226,7 +224,8 @@ def train():
break
generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict)
cv2.imwrite(f"result/transformer_img/g-{idx}.png", generated_img)
img = np.block([[real_img], [vline], [generated_img]])
cv2.imwrite(f"result/transformer_img/g-{idx}.png", img)
avg_loss_val = val_loss_total.item() / len(dataloader_val)
tqdm.write(

View File

@ -5,3 +5,4 @@ transformers
scipy
numpy
cv2
perlin-noise