mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: MaskGIT 训练调整
This commit is contained in:
parent
9814edd1b4
commit
df23c891c6
@ -18,7 +18,7 @@ def load_data(path: str):
|
|||||||
class GinkaMaskGITDataset(Dataset):
|
class GinkaMaskGITDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6,
|
self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6,
|
||||||
noise_prob=0.2, drop_prob=0.2
|
noise_prob=0.2, drop_prob=0.2, noise_sigma=0.1
|
||||||
):
|
):
|
||||||
self.data = load_data(data_path)
|
self.data = load_data(data_path)
|
||||||
self.sigma_rand = sigma_rand
|
self.sigma_rand = sigma_rand
|
||||||
@ -26,6 +26,7 @@ class GinkaMaskGITDataset(Dataset):
|
|||||||
self.blur_max = blur_max
|
self.blur_max = blur_max
|
||||||
self.noise_prob = noise_prob
|
self.noise_prob = noise_prob
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
|
self.noise_sigma = noise_sigma
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -52,8 +53,6 @@ class GinkaMaskGITDataset(Dataset):
|
|||||||
target_np = np.flipud(target_np)
|
target_np = np.flipud(target_np)
|
||||||
for i in range(0, heatmap.shape[0]):
|
for i in range(0, heatmap.shape[0]):
|
||||||
heatmap[i] = np.flipud(heatmap[i])
|
heatmap[i] = np.flipud(heatmap[i])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
target = torch.LongTensor(target_np.copy()) # [H, W]
|
target = torch.LongTensor(target_np.copy()) # [H, W]
|
||||||
cond = torch.FloatTensor(item['val']) # [cond_dim]
|
cond = torch.FloatTensor(item['val']) # [cond_dim]
|
||||||
@ -76,8 +75,8 @@ class GinkaMaskGITDataset(Dataset):
|
|||||||
|
|
||||||
for i in range(0, heatmap.shape[0]):
|
for i in range(0, heatmap.shape[0]):
|
||||||
if np.random.rand() < self.noise_prob:
|
if np.random.rand() < self.noise_prob:
|
||||||
sigma = random.random() * self.sigma_rand
|
sigma = random.random() * self.noise_sigma
|
||||||
heatmap[i] = heatmap * sigma + torch.rand_like(heatmap[i]) * (1 - sigma)
|
heatmap[i] = heatmap[i] * sigma + torch.rand_like(heatmap[i]) * (1 - sigma)
|
||||||
elif np.random.rand() < self.drop_prob:
|
elif np.random.rand() < self.drop_prob:
|
||||||
heatmap[i] = torch.zeros_like(heatmap[i])
|
heatmap[i] = torch.zeros_like(heatmap[i])
|
||||||
|
|
||||||
|
|||||||
@ -44,8 +44,9 @@ class GinkaMaskGIT(nn.Module):
|
|||||||
gate_input = torch.cat([heatmap_mean, heatmap_max], dim=1).squeeze(2).squeeze(2)
|
gate_input = torch.cat([heatmap_mean, heatmap_max], dim=1).squeeze(2).squeeze(2)
|
||||||
gate = self.cond_gate(gate_input) # [B, d_model]
|
gate = self.cond_gate(gate_input) # [B, d_model]
|
||||||
|
|
||||||
|
heatmap = heatmap * torch.sigmoid(gate).unsqueeze(2).unsqueeze(2)
|
||||||
heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1)
|
heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1)
|
||||||
x = self.tile_embedding(map) + heatmap * torch.sigmoid(gate)
|
x = self.tile_embedding(map) + heatmap
|
||||||
x = x + self.pos_embedding
|
x = x + self.pos_embedding
|
||||||
x = self.transformer(x)
|
x = self.transformer(x)
|
||||||
|
|
||||||
|
|||||||
@ -46,7 +46,7 @@ HEATMAP_CHANNEL = 9
|
|||||||
LABEL_SMOOTHING = 0
|
LABEL_SMOOTHING = 0
|
||||||
BLUR_MIN_SIZE = 3
|
BLUR_MIN_SIZE = 3
|
||||||
BLUR_MAX_SIZE = 9
|
BLUR_MAX_SIZE = 9
|
||||||
RAND_RATIO = 0.15
|
RAND_RATIO = 0.3
|
||||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||||
NUM_LAYERS = 4
|
NUM_LAYERS = 4
|
||||||
D_MODEL = 128
|
D_MODEL = 128
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user