mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 在训练时处理高斯模糊
This commit is contained in:
parent
d0f86018f1
commit
020c2cf168
4
.vscode/settings.json
vendored
Normal file
4
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"python-envs.defaultEnvManager": "ms-python.python:system",
|
||||||
|
"python-envs.defaultPackageManager": "ms-python.python:pip"
|
||||||
|
}
|
||||||
@ -268,8 +268,8 @@ const labelConfig: IAutoLabelConfig = {
|
|||||||
maxFishCount: 2,
|
maxFishCount: 2,
|
||||||
minEntryCount: 1,
|
minEntryCount: 1,
|
||||||
maxEntryCount: 4,
|
maxEntryCount: 4,
|
||||||
guassainRadius: 2,
|
guassainRadius: 0,
|
||||||
heatmapKernel: 3,
|
heatmapKernel: 0,
|
||||||
ignoreIssues: true,
|
ignoreIssues: true,
|
||||||
customTowerFilter: info => {
|
customTowerFilter: info => {
|
||||||
// if (info.name !== 'Apeiria') {
|
// if (info.name !== 'Apeiria') {
|
||||||
|
|||||||
@ -8,6 +8,7 @@ export function generateHeatmap(
|
|||||||
tokens: Set<number>,
|
tokens: Set<number>,
|
||||||
kernel: number = 5
|
kernel: number = 5
|
||||||
): number[][] {
|
): number[][] {
|
||||||
|
if (kernel === 0) return map.map(v => v.slice());
|
||||||
if (kernel % 2 !== 1) {
|
if (kernel % 2 !== 1) {
|
||||||
throw new Error(`Kernal size must be odd.`);
|
throw new Error(`Kernal size must be odd.`);
|
||||||
}
|
}
|
||||||
@ -44,6 +45,7 @@ export function generateHeatmap(
|
|||||||
* @param sigma 标准差
|
* @param sigma 标准差
|
||||||
*/
|
*/
|
||||||
export function gaussainHeatmap(map: number[][], sigma: number = 1) {
|
export function gaussainHeatmap(map: number[][], sigma: number = 1) {
|
||||||
|
if (sigma === 0) return map.map(v => v.slice());
|
||||||
const radius = sigma * 3;
|
const radius = sigma * 3;
|
||||||
const width = map[0].length;
|
const width = map[0].length;
|
||||||
const height = map.length;
|
const height = map.length;
|
||||||
|
|||||||
@ -1,5 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
|
import random
|
||||||
import torch
|
import torch
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
def load_data(path: str):
|
def load_data(path: str):
|
||||||
@ -13,9 +16,11 @@ def load_data(path: str):
|
|||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
class GinkaMaskGITDataset(Dataset):
|
class GinkaMaskGITDataset(Dataset):
|
||||||
def __init__(self, data_path: str, device):
|
def __init__(self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6):
|
||||||
self.data = load_data(data_path) # 自定义数据加载函数
|
self.data = load_data(data_path)
|
||||||
self.device = device
|
self.sigma_rand = sigma_rand
|
||||||
|
self.blur_min = blur_min
|
||||||
|
self.blur_max = blur_max
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -25,7 +30,28 @@ class GinkaMaskGITDataset(Dataset):
|
|||||||
|
|
||||||
target = torch.LongTensor(item['map']) # [H, W]
|
target = torch.LongTensor(item['map']) # [H, W]
|
||||||
cond = torch.FloatTensor(item['val']) # [cond_dim]
|
cond = torch.FloatTensor(item['val']) # [cond_dim]
|
||||||
heatmap = torch.FloatTensor(item['heatmap']) # [heatmap_channel, H, W]
|
heatmap = np.array(item['heatmap'])
|
||||||
|
|
||||||
|
if random.random() < 0.5:
|
||||||
|
size = random.randint(self.blur_min, self.blur_max)
|
||||||
|
if size % 2 == 0:
|
||||||
|
size = size + 1 if random.random() < 0.5 else size - 1
|
||||||
|
heatmap = cv2.GaussianBlur(heatmap, (size, size), 0)
|
||||||
|
else:
|
||||||
|
sizeX = random.randint(self.blur_min, self.blur_max)
|
||||||
|
sizeY = random.randint(self.blur_min, self.blur_max)
|
||||||
|
if sizeX % 2 == 0:
|
||||||
|
sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1
|
||||||
|
if sizeY % 2 == 0:
|
||||||
|
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
|
||||||
|
heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0)
|
||||||
|
|
||||||
|
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
|
||||||
|
|
||||||
|
if random.random() < 0.5:
|
||||||
|
sigma = random.random() * self.sigma_rand
|
||||||
|
rand = torch.randn_like(heatmap) * sigma
|
||||||
|
heatmap = heatmap + rand
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"cond": cond,
|
"cond": cond,
|
||||||
|
|||||||
@ -44,6 +44,8 @@ GENERATE_STEP = 8
|
|||||||
MAP_SIZE = 13 * 13
|
MAP_SIZE = 13 * 13
|
||||||
HEATMAP_CHANNEL = 9
|
HEATMAP_CHANNEL = 9
|
||||||
LABEL_SMOOTHING = 0
|
LABEL_SMOOTHING = 0
|
||||||
|
BLUR_MIN_SIZE = 3
|
||||||
|
BLUR_MAX_SIZE = 6
|
||||||
RAND_RATIO = 0.1
|
RAND_RATIO = 0.1
|
||||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||||
|
|
||||||
@ -78,13 +80,12 @@ def train():
|
|||||||
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL).to(device)
|
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL).to(device)
|
||||||
masker = MapMask([0.5, 0.5])
|
masker = MapMask([0.5, 0.5])
|
||||||
|
|
||||||
dataset = GinkaMaskGITDataset(args.train, device)
|
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
|
||||||
dataset_val = GinkaMaskGITDataset(args.validate, device)
|
dataset_val = GinkaMaskGITDataset(args.validate, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
|
||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
||||||
|
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
|
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
|
||||||
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
|
|
||||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
|
||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
@ -105,7 +106,7 @@ def train():
|
|||||||
|
|
||||||
print("Train from loaded state.")
|
print("Train from loaded state.")
|
||||||
|
|
||||||
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
|
for epoch in tqdm(range(args.epochs), desc="MaskGIT Training", disable=disable_tqdm):
|
||||||
loss_total = torch.Tensor([0]).to(device)
|
loss_total = torch.Tensor([0]).to(device)
|
||||||
|
|
||||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||||
@ -115,9 +116,6 @@ def train():
|
|||||||
B, H, W = target_map.shape
|
B, H, W = target_map.shape
|
||||||
|
|
||||||
target_map = target_map.view(B, H * W)
|
target_map = target_map.view(B, H * W)
|
||||||
rand = torch.randn_like(heatmap).to(device) * RAND_RATIO
|
|
||||||
if random.random() > 0.5:
|
|
||||||
heatmap = heatmap + rand
|
|
||||||
|
|
||||||
mask = np.zeros((B, H * W))
|
mask = np.zeros((B, H * W))
|
||||||
for i in range(B):
|
for i in range(B):
|
||||||
@ -163,7 +161,7 @@ def train():
|
|||||||
gap = 5
|
gap = 5
|
||||||
color = (255, 255, 255) # 白色
|
color = (255, 255, 255) # 白色
|
||||||
vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线
|
vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线
|
||||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
|
||||||
# 1. 常规生成
|
# 1. 常规生成
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
cond = batch["cond"].to(device)
|
cond = batch["cond"].to(device)
|
||||||
@ -226,7 +224,8 @@ def train():
|
|||||||
break
|
break
|
||||||
|
|
||||||
generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict)
|
generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict)
|
||||||
cv2.imwrite(f"result/transformer_img/g-{idx}.png", generated_img)
|
img = np.block([[real_img], [vline], [generated_img]])
|
||||||
|
cv2.imwrite(f"result/transformer_img/g-{idx}.png", img)
|
||||||
|
|
||||||
avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
|
|||||||
@ -5,3 +5,4 @@ transformers
|
|||||||
scipy
|
scipy
|
||||||
numpy
|
numpy
|
||||||
cv2
|
cv2
|
||||||
|
perlin-noise
|
||||||
Loading…
Reference in New Issue
Block a user