mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 12:57:15 +08:00
feat: 分块掩码方式
This commit is contained in:
parent
8672c52ff5
commit
0cb22e9cf7
@ -16,6 +16,7 @@ from .vae_rnn.loss import VAELoss
|
||||
from .vae_rnn.scheduler import VAEScheduler
|
||||
from .dataset import GinkaRNNDataset
|
||||
from shared.image import matrix_to_image_cv
|
||||
from .transformer.mask import MapMask
|
||||
|
||||
# 手工标注标签定义(暂时不用):
|
||||
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
|
||||
@ -52,6 +53,7 @@ NUM_CLASSES = 16
|
||||
MASK_TOKEN = 15
|
||||
GENERATE_STEP = 8
|
||||
MAP_SIZE = 13 * 13
|
||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
@ -82,6 +84,7 @@ def train():
|
||||
args = parse_arguments()
|
||||
|
||||
model = GinkaMaskGIT(num_classes=NUM_CLASSES).to(device)
|
||||
masker = MapMask([0.5, 0.5])
|
||||
|
||||
dataset = GinkaRNNDataset(args.train, device)
|
||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||
@ -98,6 +101,7 @@ def train():
|
||||
name = os.path.splitext(file)[0]
|
||||
tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)
|
||||
|
||||
# 接续训练
|
||||
if args.resume:
|
||||
data_ginka = torch.load(args.state_ginka, map_location=device)
|
||||
|
||||
@ -118,19 +122,20 @@ def train():
|
||||
B, H, W = target_map.shape
|
||||
target_map = target_map.view(B, H * W)
|
||||
|
||||
# 1. 随机采样掩码比例 r (遵循余弦调度效果更好)
|
||||
r = torch.rand(B).to(device)
|
||||
r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本
|
||||
mask = np.zeros((B, H * W))
|
||||
for i in range(B):
|
||||
mask[i] = masker.mask(H, W)
|
||||
|
||||
# 2. 生成掩码矩阵
|
||||
masks = torch.rand(target_map.shape).to(device) < r
|
||||
mask = torch.from_numpy(mask).to(torch.bool)
|
||||
|
||||
# 掩码
|
||||
masked_input = target_map.clone()
|
||||
masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记
|
||||
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
|
||||
|
||||
logits = model(masked_input, cond)
|
||||
|
||||
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=1)
|
||||
loss = (loss * masks).sum() / (masks.sum() + 1e-6)
|
||||
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
||||
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
@ -168,19 +173,20 @@ def train():
|
||||
B, H, W = target_map.shape
|
||||
target_map = target_map.view(B, H * W)
|
||||
|
||||
# 1. 随机采样掩码比例 r (遵循余弦调度效果更好)
|
||||
r = torch.rand(B).to(device)
|
||||
r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本
|
||||
mask = np.zeros((B, H * W))
|
||||
for i in range(B):
|
||||
mask[i] = masker.mask(H, W)
|
||||
|
||||
mask = torch.from_numpy(mask).to(torch.bool)
|
||||
|
||||
# 2. 生成掩码矩阵
|
||||
masks = torch.rand(target_map.shape).to(device) < r
|
||||
masked_input = target_map.clone()
|
||||
masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记
|
||||
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
|
||||
|
||||
logits = model(masked_input, cond)
|
||||
|
||||
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
||||
loss = (loss * masks.view(-1)).sum() / (masks.sum() + 1e-6)
|
||||
loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6)
|
||||
|
||||
val_loss_total += loss.detach()
|
||||
|
||||
|
||||
76
ginka/transformer/mask.py
Normal file
76
ginka/transformer/mask.py
Normal file
@ -0,0 +1,76 @@
|
||||
import random
|
||||
import torch
|
||||
import numpy as np
|
||||
from scipy.ndimage import binary_dilation, binary_erosion
|
||||
|
||||
class MapMask:
|
||||
def __init__(self, probs: list[float] = [0.5, 0.5]):
|
||||
# 掩码方案
|
||||
# 0: 纯随机掩码
|
||||
# 1: 分块随机掩码
|
||||
self.probs = [sum(probs[0:i+1]) for i in range(len(probs))]
|
||||
|
||||
def _sample_mask_ratio(self, alpha=2, beta=2, min_ratio=0.05, max_ratio=1):
|
||||
r = np.random.beta(alpha, beta)
|
||||
r = min_ratio + (max_ratio - min_ratio) * r
|
||||
return r
|
||||
|
||||
def mask(self, h: int, w: int):
|
||||
test = random.random()
|
||||
mask = None
|
||||
if test < self.probs[0]:
|
||||
mask = self.mask_random(h, w)
|
||||
elif test < self.probs[1]:
|
||||
mask = self.block_mask(h, w)
|
||||
|
||||
mask = self.random_morphology(mask)
|
||||
return mask.reshape(h * w)
|
||||
|
||||
def mask_random(self, h: int, w: int):
|
||||
# 纯随机掩码
|
||||
ratio = self._sample_mask_ratio()
|
||||
total = h * w
|
||||
num = int(total * ratio)
|
||||
|
||||
idx = np.random.choice(total, num, replace=False)
|
||||
|
||||
mask = np.zeros(total, dtype=bool)
|
||||
mask[idx] = True
|
||||
|
||||
return mask.reshape(h, w)
|
||||
|
||||
def block_mask(self, h: int, w: int, min_block=2, max_block=None):
|
||||
# 分块随机掩码
|
||||
ratio = self._sample_mask_ratio()
|
||||
if max_block is None:
|
||||
max_block = min(h, w) // 2
|
||||
|
||||
target = int(h * w * ratio)
|
||||
mask = np.zeros((h, w), dtype=bool)
|
||||
|
||||
while mask.sum() < target:
|
||||
|
||||
bw = np.random.randint(min_block, max_block + 1)
|
||||
bh = np.random.randint(min_block, max_block + 1)
|
||||
|
||||
x = np.random.randint(0, h - bh + 1)
|
||||
y = np.random.randint(0, w - bw + 1)
|
||||
|
||||
mask[x:x + bh, y:y + bw] = True
|
||||
|
||||
return mask
|
||||
|
||||
def random_morphology(self, mask, max_iter=2):
|
||||
op = np.random.choice(["none", "dilate", "erode"])
|
||||
|
||||
if op == "none":
|
||||
return mask
|
||||
|
||||
it = np.random.randint(1, max_iter + 1)
|
||||
|
||||
if op == "dilate":
|
||||
return binary_dilation(mask, iterations=it)
|
||||
|
||||
if op == "erode":
|
||||
return binary_erosion(mask, iterations=it)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user