fix: MaskGIT 训练调整

This commit is contained in:
unanmed 2026-04-09 12:47:18 +08:00
parent 9814edd1b4
commit df23c891c6
3 changed files with 7 additions and 7 deletions

View File

@ -18,7 +18,7 @@ def load_data(path: str):
class GinkaMaskGITDataset(Dataset):
def __init__(
self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6,
noise_prob=0.2, drop_prob=0.2
noise_prob=0.2, drop_prob=0.2, noise_sigma=0.1
):
self.data = load_data(data_path)
self.sigma_rand = sigma_rand
@ -26,6 +26,7 @@ class GinkaMaskGITDataset(Dataset):
self.blur_max = blur_max
self.noise_prob = noise_prob
self.drop_prob = drop_prob
self.noise_sigma = noise_sigma
def __len__(self):
return len(self.data)
@ -52,8 +53,6 @@ class GinkaMaskGITDataset(Dataset):
target_np = np.flipud(target_np)
for i in range(0, heatmap.shape[0]):
heatmap[i] = np.flipud(heatmap[i])
target = torch.LongTensor(target_np.copy()) # [H, W]
cond = torch.FloatTensor(item['val']) # [cond_dim]
@ -76,8 +75,8 @@ class GinkaMaskGITDataset(Dataset):
for i in range(0, heatmap.shape[0]):
if np.random.rand() < self.noise_prob:
sigma = random.random() * self.sigma_rand
heatmap[i] = heatmap * sigma + torch.rand_like(heatmap[i]) * (1 - sigma)
sigma = random.random() * self.noise_sigma
heatmap[i] = heatmap[i] * sigma + torch.rand_like(heatmap[i]) * (1 - sigma)
elif np.random.rand() < self.drop_prob:
heatmap[i] = torch.zeros_like(heatmap[i])

View File

@ -44,8 +44,9 @@ class GinkaMaskGIT(nn.Module):
gate_input = torch.cat([heatmap_mean, heatmap_max], dim=1).squeeze(2).squeeze(2)
gate = self.cond_gate(gate_input) # [B, d_model]
heatmap = heatmap * torch.sigmoid(gate).unsqueeze(2).unsqueeze(2)
heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1)
x = self.tile_embedding(map) + heatmap * torch.sigmoid(gate)
x = self.tile_embedding(map) + heatmap
x = x + self.pos_embedding
x = self.transformer(x)

View File

@ -46,7 +46,7 @@ HEATMAP_CHANNEL = 9
LABEL_SMOOTHING = 0
BLUR_MIN_SIZE = 3
BLUR_MAX_SIZE = 9
RAND_RATIO = 0.15
RAND_RATIO = 0.3
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
NUM_LAYERS = 4
D_MODEL = 128