mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 在训练时处理高斯模糊
This commit is contained in:
parent
d0f86018f1
commit
020c2cf168
4
.vscode/settings.json
vendored
Normal file
4
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
{
|
||||
"python-envs.defaultEnvManager": "ms-python.python:system",
|
||||
"python-envs.defaultPackageManager": "ms-python.python:pip"
|
||||
}
|
||||
@ -268,8 +268,8 @@ const labelConfig: IAutoLabelConfig = {
|
||||
maxFishCount: 2,
|
||||
minEntryCount: 1,
|
||||
maxEntryCount: 4,
|
||||
guassainRadius: 2,
|
||||
heatmapKernel: 3,
|
||||
guassainRadius: 0,
|
||||
heatmapKernel: 0,
|
||||
ignoreIssues: true,
|
||||
customTowerFilter: info => {
|
||||
// if (info.name !== 'Apeiria') {
|
||||
|
||||
@ -8,6 +8,7 @@ export function generateHeatmap(
|
||||
tokens: Set<number>,
|
||||
kernel: number = 5
|
||||
): number[][] {
|
||||
if (kernel === 0) return map.map(v => v.slice());
|
||||
if (kernel % 2 !== 1) {
|
||||
throw new Error(`Kernal size must be odd.`);
|
||||
}
|
||||
@ -44,6 +45,7 @@ export function generateHeatmap(
|
||||
* @param sigma 标准差
|
||||
*/
|
||||
export function gaussainHeatmap(map: number[][], sigma: number = 1) {
|
||||
if (sigma === 0) return map.map(v => v.slice());
|
||||
const radius = sigma * 3;
|
||||
const width = map[0].length;
|
||||
const height = map.length;
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
import json
|
||||
import random
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
def load_data(path: str):
|
||||
@ -13,9 +16,11 @@ def load_data(path: str):
|
||||
return data_list
|
||||
|
||||
class GinkaMaskGITDataset(Dataset):
|
||||
def __init__(self, data_path: str, device):
|
||||
self.data = load_data(data_path) # 自定义数据加载函数
|
||||
self.device = device
|
||||
def __init__(self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6):
|
||||
self.data = load_data(data_path)
|
||||
self.sigma_rand = sigma_rand
|
||||
self.blur_min = blur_min
|
||||
self.blur_max = blur_max
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -25,7 +30,28 @@ class GinkaMaskGITDataset(Dataset):
|
||||
|
||||
target = torch.LongTensor(item['map']) # [H, W]
|
||||
cond = torch.FloatTensor(item['val']) # [cond_dim]
|
||||
heatmap = torch.FloatTensor(item['heatmap']) # [heatmap_channel, H, W]
|
||||
heatmap = np.array(item['heatmap'])
|
||||
|
||||
if random.random() < 0.5:
|
||||
size = random.randint(self.blur_min, self.blur_max)
|
||||
if size % 2 == 0:
|
||||
size = size + 1 if random.random() < 0.5 else size - 1
|
||||
heatmap = cv2.GaussianBlur(heatmap, (size, size), 0)
|
||||
else:
|
||||
sizeX = random.randint(self.blur_min, self.blur_max)
|
||||
sizeY = random.randint(self.blur_min, self.blur_max)
|
||||
if sizeX % 2 == 0:
|
||||
sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1
|
||||
if sizeY % 2 == 0:
|
||||
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
|
||||
heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0)
|
||||
|
||||
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
|
||||
|
||||
if random.random() < 0.5:
|
||||
sigma = random.random() * self.sigma_rand
|
||||
rand = torch.randn_like(heatmap) * sigma
|
||||
heatmap = heatmap + rand
|
||||
|
||||
return {
|
||||
"cond": cond,
|
||||
|
||||
@ -44,6 +44,8 @@ GENERATE_STEP = 8
|
||||
MAP_SIZE = 13 * 13
|
||||
HEATMAP_CHANNEL = 9
|
||||
LABEL_SMOOTHING = 0
|
||||
BLUR_MIN_SIZE = 3
|
||||
BLUR_MAX_SIZE = 6
|
||||
RAND_RATIO = 0.1
|
||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||
|
||||
@ -78,13 +80,12 @@ def train():
|
||||
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL).to(device)
|
||||
masker = MapMask([0.5, 0.5])
|
||||
|
||||
dataset = GinkaMaskGITDataset(args.train, device)
|
||||
dataset_val = GinkaMaskGITDataset(args.validate, device)
|
||||
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
|
||||
dataset_val = GinkaMaskGITDataset(args.validate, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
|
||||
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
|
||||
|
||||
# 用于生成图片
|
||||
@ -105,7 +106,7 @@ def train():
|
||||
|
||||
print("Train from loaded state.")
|
||||
|
||||
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
|
||||
for epoch in tqdm(range(args.epochs), desc="MaskGIT Training", disable=disable_tqdm):
|
||||
loss_total = torch.Tensor([0]).to(device)
|
||||
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
@ -115,9 +116,6 @@ def train():
|
||||
B, H, W = target_map.shape
|
||||
|
||||
target_map = target_map.view(B, H * W)
|
||||
rand = torch.randn_like(heatmap).to(device) * RAND_RATIO
|
||||
if random.random() > 0.5:
|
||||
heatmap = heatmap + rand
|
||||
|
||||
mask = np.zeros((B, H * W))
|
||||
for i in range(B):
|
||||
@ -163,7 +161,7 @@ def train():
|
||||
gap = 5
|
||||
color = (255, 255, 255) # 白色
|
||||
vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线
|
||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||
for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
|
||||
# 1. 常规生成
|
||||
target_map = batch["target_map"].to(device)
|
||||
cond = batch["cond"].to(device)
|
||||
@ -226,7 +224,8 @@ def train():
|
||||
break
|
||||
|
||||
generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict)
|
||||
cv2.imwrite(f"result/transformer_img/g-{idx}.png", generated_img)
|
||||
img = np.block([[real_img], [vline], [generated_img]])
|
||||
cv2.imwrite(f"result/transformer_img/g-{idx}.png", img)
|
||||
|
||||
avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
||||
tqdm.write(
|
||||
|
||||
@ -5,3 +5,4 @@ transformers
|
||||
scipy
|
||||
numpy
|
||||
cv2
|
||||
perlin-noise
|
||||
Loading…
Reference in New Issue
Block a user