mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 18:31:13 +08:00
feat: 调整码字容量与策略
This commit is contained in:
parent
306d585a28
commit
416aa4dd72
@ -28,13 +28,13 @@ class GinkaMaskGIT(nn.Module):
|
||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||
|
||||
# 剩余密度投影:将 5 个浮点数 [wall, door, monster, entrance, resource] 投影为 d_z 维 token
|
||||
self.remain_proj = nn.Linear(5, d_z)
|
||||
self.remain_proj = nn.Linear(1, d_z)
|
||||
|
||||
# z 投影:逐 token 线性变换,保持序列结构
|
||||
self.z_proj = nn.Linear(d_z, d_z)
|
||||
|
||||
# 条件融合投影:z_seq_len 个 z token + 2 个结构 token + 1 个剩余密度 token
|
||||
self.cond_proj = nn.Linear((z_seq_len + 3) * d_z, d_model)
|
||||
self.cond_proj = nn.Linear((z_seq_len + 2 + 5) * d_z, d_model)
|
||||
|
||||
# 纯 encoder Transformer,条件向量 c 通过 AdaLN 注入每一层
|
||||
self.transformer = Transformer(
|
||||
@ -62,7 +62,7 @@ class GinkaMaskGIT(nn.Module):
|
||||
], dim=1)
|
||||
|
||||
# 剩余密度:连续浮点向量投影为单个 d_z 维 token,[B, 1, d_z]
|
||||
e_remain = self.remain_proj(remain).unsqueeze(1)
|
||||
e_remain = self.remain_proj(remain.unsqueeze(-1))
|
||||
|
||||
# z:逐 token 投影,保留序列结构 [B, z_seq_len, d_z]
|
||||
z_proj = self.z_proj(z)
|
||||
|
||||
@ -23,8 +23,8 @@ from shared.image import matrix_to_image_cv
|
||||
#
|
||||
# 整体架构:
|
||||
# VQ-VAE(三组独立编码器 vq1/vq2/vq3)将三阶段地图上下文分别编码为离散潜变量,
|
||||
# 再由共用 VectorQuantizer 统一量化为 z_q;
|
||||
# 三个独立 MaskGIT(mg1/mg2/mg3)分别以 z_q 和 struct_inject 为条件,
|
||||
# 再由三个独立 VectorQuantizer 分别量化为 z_q1/z_q2/z_q3;
|
||||
# 三个独立 MaskGIT(mg1/mg2/mg3)分别以各自阶段 z_q 和 struct_inject 为条件,
|
||||
# 逐阶段迭代解码地图图块序列。
|
||||
#
|
||||
# 三阶段生成目标:
|
||||
@ -37,14 +37,14 @@ from shared.image import matrix_to_image_cv
|
||||
|
||||
# 共用 VQ-VAE 超参
|
||||
# 三组编码器(vq1/vq2/vq3)共享相同超参,分别对三阶段地图上下文独立编码
|
||||
VQ_L = 2 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3)
|
||||
VQ_K = 8 # codebook 大小(离散码本条目数)
|
||||
VQ_L = 16 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3)
|
||||
VQ_K = 16 # codebook 大小(离散码本条目数)
|
||||
VQ_D_Z = 64 # 码字维度
|
||||
VQ_BETA = 0.5 # commit loss 权重(防止编码器输出漂离 codebook)
|
||||
VQ_BETA = 1.0 # commit loss 权重(防止编码器输出漂离 codebook)
|
||||
VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用)
|
||||
VQ_LAYERS = 3 # VQ-VAE Transformer 层数
|
||||
VQ_DIM_FF = 512 # VQ-VAE 前馈网络隐层维度
|
||||
VQ_D_MODEL = 128 # VQ-VAE Transformer 模型维度
|
||||
VQ_LAYERS = 6 # VQ-VAE Transformer 层数
|
||||
VQ_DIM_FF = 1024 # VQ-VAE 前馈网络隐层维度
|
||||
VQ_D_MODEL = 256 # VQ-VAE Transformer 模型维度
|
||||
VQ_NHEAD = 4 # VQ-VAE 多头注意力头数
|
||||
|
||||
# 第一阶段 MaskGIT 超参
|
||||
@ -65,10 +65,10 @@ STAGE3_MG_NHEAD = 4
|
||||
STAGE3_MG_NUM_LAYERS = 6
|
||||
STAGE3_MG_DIM_FF = 1024
|
||||
|
||||
# 三阶段 Focal Loss 损失权重(可调节各阶段对总损失的贡献比例)
|
||||
STAGE1_FOCAL_WEIGHT = 1.0
|
||||
STAGE2_FOCAL_WEIGHT = 1.0
|
||||
STAGE3_FOCAL_WEIGHT = 1.0
|
||||
# 三阶段 Cross Entropy 损失权重(可调节各阶段对总损失的贡献比例)
|
||||
STAGE1_CE_WEIGHT = 1.0
|
||||
STAGE2_CE_WEIGHT = 1.0
|
||||
STAGE3_CE_WEIGHT = 1.0
|
||||
|
||||
# 各阶段 VQ commit loss 权重(当前未单独使用,统一由 VQ_BETA 控制)
|
||||
STAGE1_VQ_WEIGHT = 0.5
|
||||
@ -95,7 +95,6 @@ MG_Z_DROPOUT = 0.1 # z 隐变量 Dropout 概率
|
||||
MG_STRUCT_DROPOUT = 0.1 # 结构参量 Dropout 概率
|
||||
|
||||
# 损失参数
|
||||
FOCAL_GAMMA = 2.0 # Focal Loss 参数
|
||||
VQ_BETA = 0.5 # 承诺损失权重
|
||||
|
||||
# 训练超参
|
||||
@ -105,6 +104,7 @@ MIN_LR = 1e-6 # 余弦退火最低学习率
|
||||
WEIGHT_DECAY = 1e-4 # L2 正则化系数
|
||||
EPOCHS = 400 # 总训练轮数
|
||||
CHECKPOINT = 20 # 每隔多少 epoch 保存检查点并执行验证
|
||||
REFERENCE_SAMPLE_PROB = 0.2 # 训练时将参考掩码图无梯度自采样 1-3 步的概率
|
||||
|
||||
device = torch.device(
|
||||
"cuda:0" if torch.cuda.is_available()
|
||||
@ -131,7 +131,7 @@ def parse_arguments():
|
||||
|
||||
def build_model(device: torch.device):
|
||||
# 三组 VQ-VAE 编码器:各自独立编码一个阶段的地图上下文(encoder_stage1/2/3)
|
||||
# 输出形状均为 [B, L, d_z],拼接后送入共用 quantizer
|
||||
# 输出形状均为 [B, L, d_z],分别送入各自阶段的 quantizer
|
||||
vq_kwargs = dict(
|
||||
num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL,
|
||||
nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_h=MAP_H, map_w=MAP_W
|
||||
@ -140,21 +140,21 @@ def build_model(device: torch.device):
|
||||
vq2 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage2 上下文(door/monster/entrance)
|
||||
vq3 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage3 上下文(resource)
|
||||
|
||||
# 三个独立 MaskGIT 解码器,均接收完整的三阶段 z_q 作为条件
|
||||
# 三个独立 MaskGIT 解码器,分别接收各自阶段的 z_q 作为条件
|
||||
mg1 = GinkaMaskGIT(
|
||||
num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF,
|
||||
nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
||||
z_seq_len=VQ_L * 3
|
||||
z_seq_len=VQ_L
|
||||
).to(device)
|
||||
mg2 = GinkaMaskGIT(
|
||||
num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF,
|
||||
nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
||||
z_seq_len=VQ_L * 3
|
||||
z_seq_len=VQ_L
|
||||
).to(device)
|
||||
mg3 = GinkaMaskGIT(
|
||||
num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF,
|
||||
nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
||||
z_seq_len=VQ_L * 3
|
||||
z_seq_len=VQ_L
|
||||
).to(device)
|
||||
|
||||
# 六个模型参数合并到同一优化器,端到端联合训练
|
||||
@ -166,18 +166,106 @@ def build_model(device: torch.device):
|
||||
# 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
||||
|
||||
# 共用 VectorQuantizer:不参与梯度更新,仅在前向时做码本查表
|
||||
quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
# 三个独立 VectorQuantizer:各阶段使用自己的码本,避免语义空间相互干扰
|
||||
quantizer1 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
quantizer2 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
quantizer3 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
quantizers = (quantizer1, quantizer2, quantizer3)
|
||||
|
||||
return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler
|
||||
return vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler
|
||||
|
||||
def focal_loss(logits, target):
|
||||
def cross_entropy_loss(logits, target):
|
||||
# logits: [B, L, C],需转为 [B, C, L] 以匹配 cross_entropy 期望格式
|
||||
ce = F.cross_entropy(logits.permute(0, 2, 1), target, reduction='none')
|
||||
pt = torch.exp(-ce) # pt = 模型对正确类的预测概率
|
||||
# Focal Loss:对高置信度样本降低权重,让模型更专注于难样本
|
||||
focal = ((1 - pt) ** FOCAL_GAMMA) * ce
|
||||
return focal.mean()
|
||||
return F.cross_entropy(logits.permute(0, 2, 1), target)
|
||||
|
||||
def summarize_codebook_hits(code_hits: torch.Tensor) -> tuple[float, float, int]:
|
||||
total_hits = code_hits.sum()
|
||||
if total_hits.item() <= 0:
|
||||
return 0.0, 0.0, 0
|
||||
|
||||
probs = code_hits / total_hits
|
||||
perplexity = torch.exp(
|
||||
-(probs * torch.log(probs.clamp_min(1e-10))).sum()
|
||||
).item()
|
||||
active_codes = int((code_hits > 0).sum().item())
|
||||
usage_rate = active_codes / code_hits.numel()
|
||||
return perplexity, usage_rate, active_codes
|
||||
|
||||
def quantize_stage_latents(
|
||||
quantizers: tuple[VectorQuantizer, VectorQuantizer, VectorQuantizer],
|
||||
z_e1: torch.Tensor,
|
||||
z_e2: torch.Tensor,
|
||||
z_e3: torch.Tensor
|
||||
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
quantizer1, quantizer2, quantizer3 = quantizers
|
||||
z_q1, _, commit_loss1, _, code_hits1 = quantizer1(z_e1)
|
||||
z_q2, _, commit_loss2, _, code_hits2 = quantizer2(z_e2)
|
||||
z_q3, _, commit_loss3, _, code_hits3 = quantizer3(z_e3)
|
||||
|
||||
commit_loss = (commit_loss1 + commit_loss2 + commit_loss3) / 3
|
||||
code_hits = torch.stack([code_hits1, code_hits2, code_hits3], dim=0)
|
||||
return (z_q1, z_q2, z_q3), commit_loss, code_hits
|
||||
|
||||
def build_reference_rollout_steps(prob: float) -> int:
|
||||
if random.random() >= prob:
|
||||
return 0
|
||||
|
||||
return random.randint(1, 3)
|
||||
|
||||
def sample_reference_inputs(
|
||||
model: torch.nn.Module,
|
||||
reference: torch.Tensor,
|
||||
z_q: torch.Tensor,
|
||||
struct: torch.Tensor,
|
||||
target_density: torch.Tensor,
|
||||
stage: int,
|
||||
rollout_steps: int
|
||||
) -> torch.Tensor:
|
||||
if rollout_steps <= 0:
|
||||
return reference
|
||||
|
||||
sampled_reference = reference.clone()
|
||||
with torch.no_grad():
|
||||
current = sampled_reference.clone()
|
||||
z_q_detached = z_q.detach()
|
||||
|
||||
for _ in range(rollout_steps):
|
||||
masked_positions = current == MASK_TOKEN
|
||||
masked_counts = masked_positions.sum(dim=1)
|
||||
if int(masked_counts.sum().item()) <= 0:
|
||||
break
|
||||
|
||||
remain = compute_remaining(current, target_density, stage)
|
||||
logits = model(current, z_q_detached, struct, remain)
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
dist = torch.distributions.Categorical(probs)
|
||||
predicted = dist.sample()
|
||||
confidence = torch.gather(
|
||||
probs,
|
||||
-1,
|
||||
predicted.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
for local_idx in range(current.size(0)):
|
||||
masked_count = int(masked_counts[local_idx].item())
|
||||
if masked_count <= 0:
|
||||
continue
|
||||
|
||||
masked_indices = masked_positions[local_idx].nonzero(as_tuple=True)[0]
|
||||
reveal_count = max(1, math.ceil(masked_count * 0.1))
|
||||
reveal_count = min(reveal_count, masked_indices.numel())
|
||||
masked_confidence = confidence[local_idx, masked_indices]
|
||||
_, reveal_order = torch.topk(
|
||||
masked_confidence,
|
||||
k=reveal_count,
|
||||
largest=True
|
||||
)
|
||||
reveal_indices = masked_indices[reveal_order]
|
||||
current[local_idx, reveal_indices] = predicted[local_idx, reveal_indices]
|
||||
|
||||
sampled_reference = current
|
||||
|
||||
return sampled_reference
|
||||
|
||||
def random_struct(device: torch.device) -> torch.Tensor:
|
||||
# 随机采样一组结构参量,用于无条件自由生成
|
||||
@ -336,14 +424,17 @@ def full_generate_random_z(
|
||||
device: torch.device,
|
||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||
) -> tuple:
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||
quantizer1, quantizer2, quantizer3 = quantizers
|
||||
|
||||
with torch.no_grad():
|
||||
z = quantizer.sample(1, VQ_L, device)
|
||||
z1 = quantizer1.sample(1, VQ_L, device)
|
||||
z2 = quantizer2.sample(1, VQ_L, device)
|
||||
z3 = quantizer3.sample(1, VQ_L, device)
|
||||
|
||||
# stage1:生成墙壁骨架
|
||||
pred1_np = maskgit_sample(
|
||||
mg1, input.clone(), z, struct, target_density, 1,
|
||||
mg1, input.clone(), z1, struct, target_density, 1,
|
||||
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
||||
)
|
||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
@ -351,7 +442,7 @@ def full_generate_random_z(
|
||||
|
||||
# stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
|
||||
pred2_np = maskgit_sample(
|
||||
mg2, inp2, z, struct, target_density, 2,
|
||||
mg2, inp2, z2, struct, target_density, 2,
|
||||
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
)
|
||||
merged12 = pred1_np.copy()
|
||||
@ -361,7 +452,7 @@ def full_generate_random_z(
|
||||
|
||||
# stage3:填充 resource(3)
|
||||
pred3_np = maskgit_sample(
|
||||
mg3, inp3, z, struct, target_density, 3,
|
||||
mg3, inp3, z3, struct, target_density, 3,
|
||||
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
)
|
||||
merged123 = merged12.copy()
|
||||
@ -371,26 +462,28 @@ def full_generate_random_z(
|
||||
|
||||
def full_generate_specific_z(
|
||||
input: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
z_q: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
struct: torch.Tensor,
|
||||
target_density: torch.Tensor,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device,
|
||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||
) -> tuple:
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||
|
||||
with torch.no_grad():
|
||||
z1, z2, z3 = z_q
|
||||
|
||||
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
|
||||
pred1_np = maskgit_sample(
|
||||
mg1, input.clone(), z, struct, target_density, 1,
|
||||
mg1, input.clone(), z1, struct, target_density, 1,
|
||||
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
||||
)
|
||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp2[inp2 == 0] = MASK_TOKEN
|
||||
|
||||
pred2_np = maskgit_sample(
|
||||
mg2, inp2, z, struct, target_density, 2,
|
||||
mg2, inp2, z2, struct, target_density, 2,
|
||||
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
)
|
||||
merged12 = pred1_np.copy()
|
||||
@ -399,7 +492,7 @@ def full_generate_specific_z(
|
||||
inp3[inp3 == 0] = MASK_TOKEN
|
||||
|
||||
pred3_np = maskgit_sample(
|
||||
mg3, inp3, z, struct, target_density, 3,
|
||||
mg3, inp3, z3, struct, target_density, 3,
|
||||
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
)
|
||||
merged123 = merged12.copy()
|
||||
@ -493,9 +586,10 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||
struct_t = batch["struct_inject"][0:1].to(device)
|
||||
target_density_t = batch["target_density"][0:1].to(device)
|
||||
z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1])
|
||||
kf = rand_keep()
|
||||
auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
|
||||
inp1_t, z_q[0:1], struct_t, target_density_t, models, device, keep_fixed=kf
|
||||
inp1_t, z_q_single, struct_t, target_density_t, models, device, keep_fixed=kf
|
||||
)
|
||||
kf_label = 'fix' if kf[0] else 'free'
|
||||
|
||||
@ -665,6 +759,7 @@ def visualize_density_var(batch, z_q, models, device, tile_dict):
|
||||
struct_t = batch["struct_inject"][0:1].to(device)
|
||||
struct_cpu = batch["struct_inject"][0]
|
||||
base_target_density = batch["target_density"][0:1].to(device)
|
||||
z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1])
|
||||
ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
gen_imgs = []
|
||||
wall_count_values = [20, 35, 50, 65, 80]
|
||||
@ -673,7 +768,7 @@ def visualize_density_var(batch, z_q, models, device, tile_dict):
|
||||
fixed_target_density[0, WALL_DENSITY_IDX] = wall_count / MAP_SIZE
|
||||
target_density_cpu = fixed_target_density[0].cpu()
|
||||
_, _, merged123 = full_generate_specific_z(
|
||||
inp1_t, z_q[0:1], struct_t, fixed_target_density, models, device
|
||||
inp1_t, z_q_single, struct_t, fixed_target_density, models, device
|
||||
)
|
||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
||||
row1 = [to_img(ref_np)] + gen_imgs[:2]
|
||||
@ -695,7 +790,8 @@ def validate(
|
||||
density_stats: dict,
|
||||
epoch: int
|
||||
):
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||
quantizer1, quantizer2, quantizer3 = quantizers
|
||||
|
||||
# 切换为推理模式(关闭 Dropout / BatchNorm 统计更新)
|
||||
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
||||
@ -706,6 +802,7 @@ def validate(
|
||||
loss2_total = torch.Tensor([0]).to(device)
|
||||
loss3_total = torch.Tensor([0]).to(device)
|
||||
commit_total = torch.Tensor([0]).to(device)
|
||||
code_hits_total = torch.zeros(3, quantizer1.K, device=device)
|
||||
|
||||
density_metrics = {
|
||||
1: {"mae": 0.0, "over": 0.0, "count": 0},
|
||||
@ -736,27 +833,30 @@ def validate(
|
||||
struct = batch["struct_inject"].to(device)
|
||||
target_density = batch["target_density"].to(device)
|
||||
|
||||
# VQ 编码:各阶段独立编码后拼接、量化
|
||||
# VQ 编码:各阶段独立编码并分别量化
|
||||
z_e1 = vq1(enc1) # [B, L, d_z]
|
||||
z_e2 = vq2(enc2)
|
||||
z_e3 = vq3(enc3)
|
||||
|
||||
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
|
||||
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
|
||||
z_q, commit_loss, code_hits = quantize_stage_latents(
|
||||
quantizers, z_e1, z_e2, z_e3
|
||||
)
|
||||
z_q1, z_q2, z_q3 = z_q
|
||||
|
||||
remain1 = compute_remaining(inp1, target_density, 1)
|
||||
remain2 = compute_remaining(inp2, target_density, 2)
|
||||
remain3 = compute_remaining(inp3, target_density, 3)
|
||||
|
||||
# 三阶段 MaskGIT 推理(均以完整 z_q、struct 和动态 remain 为条件)
|
||||
logits1 = mg1(inp1, z_q, struct, remain1)
|
||||
logits2 = mg2(inp2, z_q, struct, remain2)
|
||||
logits3 = mg3(inp3, z_q, struct, remain3)
|
||||
# 三阶段 MaskGIT 推理:各阶段只接收自己的 z_q
|
||||
logits1 = mg1(inp1, z_q1, struct, remain1)
|
||||
logits2 = mg2(inp2, z_q2, struct, remain2)
|
||||
logits3 = mg3(inp3, z_q3, struct, remain3)
|
||||
|
||||
loss1_total += focal_loss(logits1, target1)
|
||||
loss2_total += focal_loss(logits2, target2)
|
||||
loss3_total += focal_loss(logits3, target3)
|
||||
loss1_total += cross_entropy_loss(logits1, target1)
|
||||
loss2_total += cross_entropy_loss(logits2, target2)
|
||||
loss3_total += cross_entropy_loss(logits3, target3)
|
||||
commit_total += commit_loss
|
||||
code_hits_total += code_hits
|
||||
|
||||
# 计算各目标对象的真实密度误差与过量生成密度
|
||||
pred1_map = torch.argmax(logits1, dim=-1).cpu()
|
||||
@ -806,19 +906,20 @@ def validate(
|
||||
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
||||
m.train()
|
||||
|
||||
return loss1_total, loss2_total, loss3_total, commit_total
|
||||
return loss1_total, loss2_total, loss3_total, commit_total, code_hits_total
|
||||
|
||||
def train(device: torch.device):
|
||||
args = parse_arguments()
|
||||
|
||||
models = build_model(device)
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||
quantizer1, quantizer2, quantizer3 = quantizers
|
||||
|
||||
tqdm.write(f"Device: {device}")
|
||||
model_list = [
|
||||
("vq1", vq1), ("vq2", vq2), ("vq3", vq3),
|
||||
("mg1", mg1), ("mg2", mg2), ("mg3", mg3),
|
||||
("quantizer", quantizer)
|
||||
("quantizer1", quantizer1), ("quantizer2", quantizer2), ("quantizer3", quantizer3)
|
||||
]
|
||||
total_params = 0
|
||||
for name, m in model_list:
|
||||
@ -838,7 +939,15 @@ def train(device: torch.device):
|
||||
mg1.load_state_dict(ckpt["mg1"])
|
||||
mg2.load_state_dict(ckpt["mg2"])
|
||||
mg3.load_state_dict(ckpt["mg3"])
|
||||
quantizer.load_state_dict(ckpt["quantizer"])
|
||||
if "quantizer1" in ckpt:
|
||||
quantizer1.load_state_dict(ckpt["quantizer1"])
|
||||
quantizer2.load_state_dict(ckpt["quantizer2"])
|
||||
quantizer3.load_state_dict(ckpt["quantizer3"])
|
||||
elif "quantizer" in ckpt:
|
||||
quantizer1.load_state_dict(ckpt["quantizer"])
|
||||
quantizer2.load_state_dict(ckpt["quantizer"])
|
||||
quantizer3.load_state_dict(ckpt["quantizer"])
|
||||
tqdm.write("Loaded legacy shared quantizer weights into quantizer1/2/3")
|
||||
# load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练)
|
||||
if args.load_optim and "optimizer" in ckpt:
|
||||
optimizer.load_state_dict(ckpt["optimizer"])
|
||||
@ -878,6 +987,7 @@ def train(device: torch.device):
|
||||
loss2_total = torch.Tensor([0]).to(device)
|
||||
loss3_total = torch.Tensor([0]).to(device)
|
||||
commit_total = torch.Tensor([0]).to(device)
|
||||
code_hits_total = torch.zeros(3, quantizer1.K, device=device)
|
||||
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
# 三阶段各自的掩码输入序列、预测目标和编码器上下文
|
||||
@ -904,26 +1014,39 @@ def train(device: torch.device):
|
||||
z_e2 = vq2(enc2)
|
||||
z_e3 = vq3(enc3)
|
||||
|
||||
# 合并三阶段编码后量化
|
||||
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
|
||||
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
|
||||
# 三阶段分别量化,各自使用独立 codebook
|
||||
z_q, commit_loss, code_hits = quantize_stage_latents(
|
||||
quantizers, z_e1, z_e2, z_e3
|
||||
)
|
||||
z_q1, z_q2, z_q3 = z_q
|
||||
|
||||
rollout_steps = build_reference_rollout_steps(REFERENCE_SAMPLE_PROB)
|
||||
inp1 = sample_reference_inputs(
|
||||
mg1, inp1, z_q1, struct, target_density, 1, rollout_steps
|
||||
)
|
||||
inp2 = sample_reference_inputs(
|
||||
mg2, inp2, z_q2, struct, target_density, 2, rollout_steps
|
||||
)
|
||||
inp3 = sample_reference_inputs(
|
||||
mg3, inp3, z_q3, struct, target_density, 3, rollout_steps
|
||||
)
|
||||
|
||||
remain1 = compute_remaining(inp1, target_density, 1)
|
||||
remain2 = compute_remaining(inp2, target_density, 2)
|
||||
remain3 = compute_remaining(inp3, target_density, 3)
|
||||
|
||||
# 三阶段 MaskGIT 前向(均接收完整三阶段 z_q、struct 和动态 remain 条件)
|
||||
logits1 = mg1(inp1, z_q, struct, remain1)
|
||||
logits2 = mg2(inp2, z_q, struct, remain2)
|
||||
logits3 = mg3(inp3, z_q, struct, remain3)
|
||||
# 三阶段 MaskGIT 前向:各阶段只接收自己的 z_q、struct 和动态 remain 条件
|
||||
logits1 = mg1(inp1, z_q1, struct, remain1)
|
||||
logits2 = mg2(inp2, z_q2, struct, remain2)
|
||||
logits3 = mg3(inp3, z_q3, struct, remain3)
|
||||
|
||||
# 三阶段 Focal Loss + VQ commit loss 加权求和
|
||||
loss1 = focal_loss(logits1, target1)
|
||||
loss2 = focal_loss(logits2, target2)
|
||||
loss3 = focal_loss(logits3, target3)
|
||||
loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1
|
||||
loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2
|
||||
loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3
|
||||
# 三阶段 Cross Entropy + VQ commit loss 加权求和
|
||||
loss1 = cross_entropy_loss(logits1, target1)
|
||||
loss2 = cross_entropy_loss(logits2, target2)
|
||||
loss3 = cross_entropy_loss(logits3, target3)
|
||||
loss1_weighted = STAGE1_CE_WEIGHT * loss1
|
||||
loss2_weighted = STAGE2_CE_WEIGHT * loss2
|
||||
loss3_weighted = STAGE3_CE_WEIGHT * loss3
|
||||
commit_weighted = VQ_BETA * commit_loss
|
||||
loss = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
|
||||
|
||||
@ -936,11 +1059,13 @@ def train(device: torch.device):
|
||||
loss2_total += loss2.detach()
|
||||
loss3_total += loss3.detach()
|
||||
commit_total += commit_loss.detach()
|
||||
code_hits_total += code_hits.detach()
|
||||
|
||||
# 每个 epoch 结束后更新学习率
|
||||
scheduler.step()
|
||||
|
||||
data_length = len(dataloader)
|
||||
train_perplexity, train_usage_rate, train_active_codes = summarize_codebook_hits(code_hits_total)
|
||||
tqdm.write(
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||
f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
|
||||
@ -948,20 +1073,23 @@ def train(device: torch.device):
|
||||
f"L2: {loss2_total.item() / data_length:.6f} | "
|
||||
f"L3: {loss3_total.item() / data_length:.6f} | "
|
||||
f"VQ: {commit_total.item() / data_length:.6f} | "
|
||||
f"PPL: {train_perplexity:.4f} | "
|
||||
f"Usage: {train_usage_rate:.4f} ({train_active_codes}/{code_hits_total.numel()}) | "
|
||||
f"LR: {scheduler.get_last_lr()[0]:.6f}"
|
||||
)
|
||||
|
||||
# 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存
|
||||
if (epoch + 1) % CHECKPOINT == 0:
|
||||
losses = validate(dataloader_val, models, device, tile_dict, dataset.density_stats, epoch + 1)
|
||||
loss1_total, loss2_total, loss3_total, commit_total = losses
|
||||
loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1_total
|
||||
loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2_total
|
||||
loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3_total
|
||||
loss1_total, loss2_total, loss3_total, commit_total, code_hits_total = losses
|
||||
loss1_weighted = STAGE1_CE_WEIGHT * loss1_total
|
||||
loss2_weighted = STAGE2_CE_WEIGHT * loss2_total
|
||||
loss3_weighted = STAGE3_CE_WEIGHT * loss3_total
|
||||
commit_weighted = VQ_BETA * commit_total
|
||||
loss_total = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
|
||||
|
||||
data_length = len(dataloader_val)
|
||||
val_perplexity, val_usage_rate, val_active_codes = summarize_codebook_hits(code_hits_total)
|
||||
tqdm.write(
|
||||
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||
f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
|
||||
@ -969,6 +1097,8 @@ def train(device: torch.device):
|
||||
f"L2: {loss2_total.item() / data_length:.6f} | "
|
||||
f"L3: {loss3_total.item() / data_length:.6f} | "
|
||||
f"VQ: {commit_total.item() / data_length:.6f} | "
|
||||
f"PPL: {val_perplexity:.4f} | "
|
||||
f"Usage: {val_usage_rate:.4f} ({val_active_codes}/{code_hits_total.numel()}) | "
|
||||
)
|
||||
|
||||
ckpt_path = f"result/seperated/sep-{epoch + 1}.pth"
|
||||
@ -980,7 +1110,9 @@ def train(device: torch.device):
|
||||
"mg1": mg1.state_dict(),
|
||||
"mg2": mg2.state_dict(),
|
||||
"mg3": mg3.state_dict(),
|
||||
"quantizer": quantizer.state_dict(),
|
||||
"quantizer1": quantizer1.state_dict(),
|
||||
"quantizer2": quantizer2.state_dict(),
|
||||
"quantizer3": quantizer3.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
}, ckpt_path)
|
||||
@ -996,7 +1128,9 @@ def train(device: torch.device):
|
||||
"mg1": mg1.state_dict(),
|
||||
"mg2": mg2.state_dict(),
|
||||
"mg3": mg3.state_dict(),
|
||||
"quantizer": quantizer.state_dict(),
|
||||
"quantizer1": quantizer1.state_dict(),
|
||||
"quantizer2": quantizer2.state_dict(),
|
||||
"quantizer3": quantizer3.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
}, final_path)
|
||||
|
||||
@ -18,8 +18,23 @@ class VectorQuantizer(nn.Module):
|
||||
self.codebook = nn.Embedding(K, d_z)
|
||||
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
|
||||
|
||||
def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# z_e: [B, L * 3, d_z]
|
||||
def codebook_stats(
|
||||
self, indices: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
flat_indices = indices.reshape(-1)
|
||||
one_hot = F.one_hot(flat_indices, num_classes=self.K).float()
|
||||
avg_probs = one_hot.mean(dim=0)
|
||||
perplexity = torch.exp(
|
||||
-(avg_probs * torch.log(avg_probs.clamp_min(1e-10))).sum()
|
||||
)
|
||||
usage_rate = (avg_probs > 0).float().mean()
|
||||
usage_count = one_hot.sum(dim=0)
|
||||
return perplexity, usage_rate, usage_count
|
||||
|
||||
def forward(
|
||||
self, z_e: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# z_e: [B, L, d_z]
|
||||
"""
|
||||
Args:
|
||||
z_e: [B, L, d_z] 编码器输出的连续向量序列
|
||||
@ -31,7 +46,7 @@ class VectorQuantizer(nn.Module):
|
||||
"""
|
||||
B, L, d_z = z_e.shape
|
||||
|
||||
z_flat = z_e.reshape(B * L, d_z) # [B * L * 3, d_z]
|
||||
z_flat = z_e.reshape(B * L, d_z) # [B * L, d_z]
|
||||
|
||||
codebook_w = self.codebook.weight # [K, d_z]
|
||||
|
||||
@ -56,13 +71,9 @@ class VectorQuantizer(nn.Module):
|
||||
commit_loss = F.mse_loss(z_e, z_q.detach())
|
||||
|
||||
indices = indices.reshape(B, L)
|
||||
return z_q_st, indices, commit_loss
|
||||
perplexity, usage_rate, usage_count = self.codebook_stats(indices)
|
||||
return z_q_st, indices, commit_loss, perplexity, usage_count
|
||||
|
||||
def sample(self, B: int, L: int, device: torch.device) -> torch.Tensor:
|
||||
indices1 = torch.randint(0, self.K, (B, L), device=device)
|
||||
indices2 = torch.randint(0, self.K, (B, L), device=device)
|
||||
indices3 = torch.randint(0, self.K, (B, L), device=device)
|
||||
z1 = self.codebook(indices1)
|
||||
z2 = self.codebook(indices2)
|
||||
z3 = self.codebook(indices3)
|
||||
return torch.cat([z1, z2, z3], dim=1)
|
||||
indices = torch.randint(0, self.K, (B, L), device=device)
|
||||
return self.codebook(indices)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user