feat: 降低 MaskGIT 对条件的依赖

This commit is contained in:
unanmed 2026-04-08 19:36:46 +08:00
parent cbbe312444
commit dbb0b9064c
6 changed files with 67 additions and 26 deletions

View File

@ -16,11 +16,16 @@ def load_data(path: str):
return data_list return data_list
class GinkaMaskGITDataset(Dataset): class GinkaMaskGITDataset(Dataset):
def __init__(self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6): def __init__(
self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6,
noise_prob=0.2, drop_prob=0.2
):
self.data = load_data(data_path) self.data = load_data(data_path)
self.sigma_rand = sigma_rand self.sigma_rand = sigma_rand
self.blur_min = blur_min self.blur_min = blur_min
self.blur_max = blur_max self.blur_max = blur_max
self.noise_prob = noise_prob
self.drop_prob = drop_prob
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
@ -47,6 +52,8 @@ class GinkaMaskGITDataset(Dataset):
target_np = np.flipud(target_np) target_np = np.flipud(target_np)
for i in range(0, heatmap.shape[0]): for i in range(0, heatmap.shape[0]):
heatmap[i] = np.flipud(heatmap[i]) heatmap[i] = np.flipud(heatmap[i])
target = torch.LongTensor(target_np.copy()) # [H, W] target = torch.LongTensor(target_np.copy()) # [H, W]
cond = torch.FloatTensor(item['val']) # [cond_dim] cond = torch.FloatTensor(item['val']) # [cond_dim]
@ -65,12 +72,19 @@ class GinkaMaskGITDataset(Dataset):
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1 sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0) heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0)
# Per-channel corruption of the conditioning heatmap: each channel is either
# blended with Gaussian noise (probability `noise_prob`) or, failing that,
# zeroed out (probability `drop_prob`, evaluated only when the noise branch
# was not taken).
for i in range(0, heatmap.shape[0]):
    if np.random.rand() < self.noise_prob:
        sigma = random.random() * self.sigma_rand
        # Blend only the current channel. The previous code used the whole
        # `heatmap` array on the right-hand side, which cannot be assigned
        # into the single-channel slot `heatmap[i]` when there is more than
        # one channel. Noise is drawn per pixel instead of one scalar
        # offset shared by the entire channel.
        # NOTE(review): the weights keep the original orientation
        # (signal * sigma, noise * (1 - sigma)) - confirm this is intended.
        noise = np.random.randn(*heatmap[i].shape)
        heatmap[i] = heatmap[i] * sigma + noise * (1 - sigma)
    elif np.random.rand() < self.drop_prob:
        heatmap[i] = np.zeros_like(heatmap[i])
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W] heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
if random.random() < 0.5: if random.random() < 0.5:
sigma = random.random() * self.sigma_rand sigma = random.random() * self.sigma_rand
rand = torch.randn_like(heatmap) * sigma rand = torch.randn_like(heatmap)
heatmap = heatmap + rand heatmap = heatmap * (1 - sigma) + rand * sigma
return { return {
"cond": cond, "cond": cond,

View File

@ -2,7 +2,7 @@ import time
import torch import torch
import torch.nn as nn import torch.nn as nn
from .cond import HeatmapCond from .cond import HeatmapCond
from ..maskGIT.maskGIT import MaskGIT from ..maskGIT.maskGIT import Transformer
from ..utils import print_memory from ..utils import print_memory
class GinkaHeatmapModel(nn.Module): class GinkaHeatmapModel(nn.Module):
@ -15,7 +15,7 @@ class GinkaHeatmapModel(nn.Module):
self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model)) self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model))
self.cond = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model) self.cond = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model)
self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model) self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model)
self.transformer = MaskGIT(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers) self.transformer = Transformer(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers)
self.cross_attn = nn.MultiheadAttention(d_model, num_heads=nhead, batch_first=True) self.cross_attn = nn.MultiheadAttention(d_model, num_heads=nhead, batch_first=True)
self.output_fc = nn.Sequential( self.output_fc = nn.Sequential(
nn.Linear(d_model, d_model // 2), nn.Linear(d_model, d_model // 2),

View File

@ -4,19 +4,19 @@ import torch.nn as nn
from ..utils import print_memory from ..utils import print_memory
class GinkaMaskGITCond(nn.Module): class GinkaMaskGITCond(nn.Module):
def __init__(self, heatmap_channel=4, output_dim=256): def __init__(self, input_channel=4, channels=[32, 64, 128]):
super().__init__() super().__init__()
self.heatmap_conv = nn.Sequential( self.heatmap_conv = nn.Sequential(
nn.Conv2d(heatmap_channel, output_dim // 4, kernel_size=3, padding=1, padding_mode='replicate'), nn.Conv2d(input_channel, channels[0], kernel_size=3, padding=1, padding_mode='replicate'),
nn.BatchNorm2d(output_dim // 4), nn.BatchNorm2d(channels[0]),
nn.GELU(), nn.GELU(),
nn.Conv2d(output_dim // 4, output_dim // 2, kernel_size=3, padding=1, padding_mode='replicate'), nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, padding_mode='replicate'),
nn.BatchNorm2d(output_dim // 2), nn.BatchNorm2d(channels[1]),
nn.GELU(), nn.GELU(),
nn.Conv2d(output_dim // 2, output_dim, kernel_size=3, padding=1, padding_mode='replicate'), nn.Conv2d(channels[1], channels[2], kernel_size=3, padding=1, padding_mode='replicate'),
nn.BatchNorm2d(output_dim), nn.BatchNorm2d(channels[2]),
nn.GELU() nn.GELU()
) )

View File

@ -1,6 +1,6 @@
import torch.nn as nn import torch.nn as nn
class MaskGIT(nn.Module): class Transformer(nn.Module):
def __init__( def __init__(
self, d_model=256, dim_ff=512, nhead=8, num_layers=4, self, d_model=256, dim_ff=512, nhead=8, num_layers=4,
): ):

View File

@ -1,9 +1,10 @@
import time import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from ..utils import print_memory from ..utils import print_memory
from .cond import GinkaMaskGITCond from .cond import GinkaMaskGITCond
from .maskGIT import MaskGIT from .maskGIT import Transformer
class GinkaMaskGIT(nn.Module): class GinkaMaskGIT(nn.Module):
def __init__( def __init__(
@ -15,9 +16,18 @@ class GinkaMaskGIT(nn.Module):
self.tile_embedding = nn.Embedding(num_classes, d_model) self.tile_embedding = nn.Embedding(num_classes, d_model)
self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model)) self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model))
self.cond_encoder = GinkaMaskGITCond(heatmap_channel=heatmap_channel, output_dim=d_model) cond_channels = [d_model // 4, d_model // 2, d_model]
self.cond_encoder = GinkaMaskGITCond(input_channel=heatmap_channel, channels=cond_channels)
self.cond_gate = nn.Sequential(
nn.Linear(cond_channels[2] * 2, cond_channels[2]),
nn.LayerNorm(cond_channels[2]),
nn.Dropout(0.3),
nn.GELU(),
nn.Linear(cond_channels[2], cond_channels[2])
)
self.transformer = MaskGIT(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers) self.transformer = Transformer(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers)
self.output_fc = nn.Sequential( self.output_fc = nn.Sequential(
nn.Linear(d_model, num_classes) nn.Linear(d_model, num_classes)
@ -27,14 +37,15 @@ class GinkaMaskGIT(nn.Module):
# map: [B, H * W] # map: [B, H * W]
# heatmap: [B, C, H, W] # heatmap: [B, C, H, W]
# output: [B, H * W, num_classes] # output: [B, H * W, num_classes]
heatmap = self.cond_encoder(heatmap) heatmap = self.cond_encoder(heatmap) # [B, d_model, H, W]
# cond: [B, d_model]
# heatmap: [B, d_model, H, W]
B, C, H, W = heatmap.shape B, C, H, W = heatmap.shape
heatmap_mean = F.avg_pool2d(heatmap, (H, W)) # [B, d_model, 1, 1]
heatmap_max = F.max_pool2d(heatmap, (H, W)) # [B, d_model, 1, 1]
gate_input = torch.cat([heatmap_mean, heatmap_max], dim=1).squeeze(2).squeeze(2)
gate = self.cond_gate(gate_input) # [B, d_model]
heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1) heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1)
x = self.tile_embedding(map) + heatmap x = self.tile_embedding(map) + heatmap * torch.sigmoid(gate)
x = x + self.pos_embedding x = x + self.pos_embedding
x = self.transformer(x) x = self.transformer(x)
@ -64,6 +75,7 @@ if __name__ == "__main__":
print(f"输出形状: output={output.shape}") print(f"输出形状: output={output.shape}")
print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}") print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond_encoder.parameters())}") print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond_encoder.parameters())}")
print(f"Condition Gate parameters: {sum(p.numel() for p in model.cond_gate.parameters())}")
print(f"MaskGIT parameters: {sum(p.numel() for p in model.transformer.parameters())}") print(f"MaskGIT parameters: {sum(p.numel() for p in model.transformer.parameters())}")
print(f"Output parameters: {sum(p.numel() for p in model.output_fc.parameters())}") print(f"Output parameters: {sum(p.numel() for p in model.output_fc.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -47,8 +47,7 @@ NUM_LAYERS_DIFFUSION = 4
D_MODEL_DIFFUSION = 128 D_MODEL_DIFFUSION = 128
T_DIFFUSION = 100 T_DIFFUSION = 100
MIN_MASK = 0 MIN_MASK = 0
MAX_MASK = 0.8 MAX_MASK = 1
NOISE_SAMPLING_K = [40, 15, 21, 8, 8, 4, 1, 2, 10]
W = 5 # CFG 参数 W = 5 # CFG 参数
device = torch.device( device = torch.device(
@ -185,7 +184,7 @@ def train():
pred_noise = model(x_t, cond_heatmap, t) pred_noise = model(x_t, cond_heatmap, t)
loss = F.mse_loss(pred_noise, noise) loss = F.l1_loss(pred_noise, noise)
val_loss_total += loss.detach() val_loss_total += loss.detach()
@ -202,9 +201,10 @@ def train():
if args.use_maskgit: if args.use_maskgit:
for i in range(0, 5): for i in range(0, 5):
ar = np.ndarray([1, HEATMAP_CHANNEL, MAP_H, MAP_W]) ar = np.ndarray([1, HEATMAP_CHANNEL, MAP_H, MAP_W])
k = get_nms_sampling_count()
for c in range(0, HEATMAP_CHANNEL): for c in range(0, HEATMAP_CHANNEL):
noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H,0:MAP_W] noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H,0:MAP_W]
ar[0,c] = nms_sampling(noise, NOISE_SAMPLING_K[c]) ar[0,c] = nms_sampling(noise, k[c])
map = full_generate(model, maskGIT, torch.FloatTensor(ar).to(device), diffusion) map = full_generate(model, maskGIT, torch.FloatTensor(ar).to(device), diffusion)
generated_img = matrix_to_image_cv(map.view(1, H, W)[0].cpu().numpy(), tile_dict) generated_img = matrix_to_image_cv(map.view(1, H, W)[0].cpu().numpy(), tile_dict)
@ -221,12 +221,27 @@ def train():
"model_state": maskGIT.state_dict(), "model_state": maskGIT.state_dict(),
}, f"result/ginka_heatmap.pth") }, f"result/ginka_heatmap.pth")
def get_nms_sampling_count():
    """Draw one random NMS sampling count per entry of the bounds table.

    Each count comes from ``np.random.randint(low, high)``, so ``low`` is
    inclusive and ``high`` is exclusive. The visible call site indexes the
    result by heatmap channel (``k[c]``).

    Returns:
        A list of nine integers, in the same order as the bounds table.
    """
    # (low, high) bounds, one pair per position of the returned list.
    bounds = (
        (20, 40),
        (10, 20),
        (10, 30),
        (4, 12),
        (4, 12),
        (2, 6),
        (0, 2),
        (1, 3),
        (2, 10),
    )
    # Draw sequentially so the random stream is consumed in table order.
    return [np.random.randint(low, high) for low, high in bounds]
@torch.no_grad()
def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion):
    """Sample a heatmap with classifier-free guidance, then generate a map from it.

    Args:
        heatmap: Model handed to ``diffusion.sample`` (the visible call site
            passes ``model`` here).
        maskGIT: Model forwarded to ``maskGIT_generate``.
        cond_heatmap: Conditioning heatmap; its first dimension is used as the
            batch size.
        diffusion: Sampler exposing ``sample(model, cond)``.

    Returns:
        Whatever ``maskGIT_generate`` returns for the guided heatmap.
    """
    fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap)
    # The unconditional branch uses an all-zero condition.
    fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap))
    # Classifier-free guidance: start from the unconditional result and move
    # TOWARDS the conditional one by a factor of W. The previous expression
    # used (uncond - cond), which pushes the result away from the condition.
    fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_cond - fake_heatmap_uncond)
    return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap)
@torch.no_grad()
def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor): def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor):
map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device) map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device)
for i in range(GENERATE_STEP): for i in range(GENERATE_STEP):