mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 18:31:13 +08:00
feat: 调整码字容量与策略
This commit is contained in:
parent
306d585a28
commit
416aa4dd72
@ -28,13 +28,13 @@ class GinkaMaskGIT(nn.Module):
|
|||||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||||
|
|
||||||
# 剩余密度投影:将 5 个浮点数 [wall, door, monster, entrance, resource] 投影为 d_z 维 token
|
# 剩余密度投影:将 5 个浮点数 [wall, door, monster, entrance, resource] 投影为 d_z 维 token
|
||||||
self.remain_proj = nn.Linear(5, d_z)
|
self.remain_proj = nn.Linear(1, d_z)
|
||||||
|
|
||||||
# z 投影:逐 token 线性变换,保持序列结构
|
# z 投影:逐 token 线性变换,保持序列结构
|
||||||
self.z_proj = nn.Linear(d_z, d_z)
|
self.z_proj = nn.Linear(d_z, d_z)
|
||||||
|
|
||||||
# 条件融合投影:z_seq_len 个 z token + 2 个结构 token + 1 个剩余密度 token
|
# 条件融合投影:z_seq_len 个 z token + 2 个结构 token + 1 个剩余密度 token
|
||||||
self.cond_proj = nn.Linear((z_seq_len + 3) * d_z, d_model)
|
self.cond_proj = nn.Linear((z_seq_len + 2 + 5) * d_z, d_model)
|
||||||
|
|
||||||
# 纯 encoder Transformer,条件向量 c 通过 AdaLN 注入每一层
|
# 纯 encoder Transformer,条件向量 c 通过 AdaLN 注入每一层
|
||||||
self.transformer = Transformer(
|
self.transformer = Transformer(
|
||||||
@ -62,7 +62,7 @@ class GinkaMaskGIT(nn.Module):
|
|||||||
], dim=1)
|
], dim=1)
|
||||||
|
|
||||||
# 剩余密度:连续浮点向量投影为单个 d_z 维 token,[B, 1, d_z]
|
# 剩余密度:连续浮点向量投影为单个 d_z 维 token,[B, 1, d_z]
|
||||||
e_remain = self.remain_proj(remain).unsqueeze(1)
|
e_remain = self.remain_proj(remain.unsqueeze(-1))
|
||||||
|
|
||||||
# z:逐 token 投影,保留序列结构 [B, z_seq_len, d_z]
|
# z:逐 token 投影,保留序列结构 [B, z_seq_len, d_z]
|
||||||
z_proj = self.z_proj(z)
|
z_proj = self.z_proj(z)
|
||||||
|
|||||||
@ -23,8 +23,8 @@ from shared.image import matrix_to_image_cv
|
|||||||
#
|
#
|
||||||
# 整体架构:
|
# 整体架构:
|
||||||
# VQ-VAE(三组独立编码器 vq1/vq2/vq3)将三阶段地图上下文分别编码为离散潜变量,
|
# VQ-VAE(三组独立编码器 vq1/vq2/vq3)将三阶段地图上下文分别编码为离散潜变量,
|
||||||
# 再由共用 VectorQuantizer 统一量化为 z_q;
|
# 再由三个独立 VectorQuantizer 分别量化为 z_q1/z_q2/z_q3;
|
||||||
# 三个独立 MaskGIT(mg1/mg2/mg3)分别以 z_q 和 struct_inject 为条件,
|
# 三个独立 MaskGIT(mg1/mg2/mg3)分别以各自阶段 z_q 和 struct_inject 为条件,
|
||||||
# 逐阶段迭代解码地图图块序列。
|
# 逐阶段迭代解码地图图块序列。
|
||||||
#
|
#
|
||||||
# 三阶段生成目标:
|
# 三阶段生成目标:
|
||||||
@ -37,14 +37,14 @@ from shared.image import matrix_to_image_cv
|
|||||||
|
|
||||||
# 共用 VQ-VAE 超参
|
# 共用 VQ-VAE 超参
|
||||||
# 三组编码器(vq1/vq2/vq3)共享相同超参,分别对三阶段地图上下文独立编码
|
# 三组编码器(vq1/vq2/vq3)共享相同超参,分别对三阶段地图上下文独立编码
|
||||||
VQ_L = 2 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3)
|
VQ_L = 16 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3)
|
||||||
VQ_K = 8 # codebook 大小(离散码本条目数)
|
VQ_K = 16 # codebook 大小(离散码本条目数)
|
||||||
VQ_D_Z = 64 # 码字维度
|
VQ_D_Z = 64 # 码字维度
|
||||||
VQ_BETA = 0.5 # commit loss 权重(防止编码器输出漂离 codebook)
|
VQ_BETA = 1.0 # commit loss 权重(防止编码器输出漂离 codebook)
|
||||||
VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用)
|
VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用)
|
||||||
VQ_LAYERS = 3 # VQ-VAE Transformer 层数
|
VQ_LAYERS = 6 # VQ-VAE Transformer 层数
|
||||||
VQ_DIM_FF = 512 # VQ-VAE 前馈网络隐层维度
|
VQ_DIM_FF = 1024 # VQ-VAE 前馈网络隐层维度
|
||||||
VQ_D_MODEL = 128 # VQ-VAE Transformer 模型维度
|
VQ_D_MODEL = 256 # VQ-VAE Transformer 模型维度
|
||||||
VQ_NHEAD = 4 # VQ-VAE 多头注意力头数
|
VQ_NHEAD = 4 # VQ-VAE 多头注意力头数
|
||||||
|
|
||||||
# 第一阶段 MaskGIT 超参
|
# 第一阶段 MaskGIT 超参
|
||||||
@ -65,10 +65,10 @@ STAGE3_MG_NHEAD = 4
|
|||||||
STAGE3_MG_NUM_LAYERS = 6
|
STAGE3_MG_NUM_LAYERS = 6
|
||||||
STAGE3_MG_DIM_FF = 1024
|
STAGE3_MG_DIM_FF = 1024
|
||||||
|
|
||||||
# 三阶段 Focal Loss 损失权重(可调节各阶段对总损失的贡献比例)
|
# 三阶段 Cross Entropy 损失权重(可调节各阶段对总损失的贡献比例)
|
||||||
STAGE1_FOCAL_WEIGHT = 1.0
|
STAGE1_CE_WEIGHT = 1.0
|
||||||
STAGE2_FOCAL_WEIGHT = 1.0
|
STAGE2_CE_WEIGHT = 1.0
|
||||||
STAGE3_FOCAL_WEIGHT = 1.0
|
STAGE3_CE_WEIGHT = 1.0
|
||||||
|
|
||||||
# 各阶段 VQ commit loss 权重(当前未单独使用,统一由 VQ_BETA 控制)
|
# 各阶段 VQ commit loss 权重(当前未单独使用,统一由 VQ_BETA 控制)
|
||||||
STAGE1_VQ_WEIGHT = 0.5
|
STAGE1_VQ_WEIGHT = 0.5
|
||||||
@ -95,7 +95,6 @@ MG_Z_DROPOUT = 0.1 # z 隐变量 Dropout 概率
|
|||||||
MG_STRUCT_DROPOUT = 0.1 # 结构参量 Dropout 概率
|
MG_STRUCT_DROPOUT = 0.1 # 结构参量 Dropout 概率
|
||||||
|
|
||||||
# 损失参数
|
# 损失参数
|
||||||
FOCAL_GAMMA = 2.0 # Focal Loss 参数
|
|
||||||
VQ_BETA = 0.5 # 承诺损失权重
|
VQ_BETA = 0.5 # 承诺损失权重
|
||||||
|
|
||||||
# 训练超参
|
# 训练超参
|
||||||
@ -105,6 +104,7 @@ MIN_LR = 1e-6 # 余弦退火最低学习率
|
|||||||
WEIGHT_DECAY = 1e-4 # L2 正则化系数
|
WEIGHT_DECAY = 1e-4 # L2 正则化系数
|
||||||
EPOCHS = 400 # 总训练轮数
|
EPOCHS = 400 # 总训练轮数
|
||||||
CHECKPOINT = 20 # 每隔多少 epoch 保存检查点并执行验证
|
CHECKPOINT = 20 # 每隔多少 epoch 保存检查点并执行验证
|
||||||
|
REFERENCE_SAMPLE_PROB = 0.2 # 训练时将参考掩码图无梯度自采样 1-3 步的概率
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
"cuda:0" if torch.cuda.is_available()
|
"cuda:0" if torch.cuda.is_available()
|
||||||
@ -131,7 +131,7 @@ def parse_arguments():
|
|||||||
|
|
||||||
def build_model(device: torch.device):
|
def build_model(device: torch.device):
|
||||||
# 三组 VQ-VAE 编码器:各自独立编码一个阶段的地图上下文(encoder_stage1/2/3)
|
# 三组 VQ-VAE 编码器:各自独立编码一个阶段的地图上下文(encoder_stage1/2/3)
|
||||||
# 输出形状均为 [B, L, d_z],拼接后送入共用 quantizer
|
# 输出形状均为 [B, L, d_z],分别送入各自阶段的 quantizer
|
||||||
vq_kwargs = dict(
|
vq_kwargs = dict(
|
||||||
num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL,
|
num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL,
|
||||||
nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_h=MAP_H, map_w=MAP_W
|
nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_h=MAP_H, map_w=MAP_W
|
||||||
@ -140,21 +140,21 @@ def build_model(device: torch.device):
|
|||||||
vq2 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage2 上下文(door/monster/entrance)
|
vq2 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage2 上下文(door/monster/entrance)
|
||||||
vq3 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage3 上下文(resource)
|
vq3 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage3 上下文(resource)
|
||||||
|
|
||||||
# 三个独立 MaskGIT 解码器,均接收完整的三阶段 z_q 作为条件
|
# 三个独立 MaskGIT 解码器,分别接收各自阶段的 z_q 作为条件
|
||||||
mg1 = GinkaMaskGIT(
|
mg1 = GinkaMaskGIT(
|
||||||
num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF,
|
num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF,
|
||||||
nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
||||||
z_seq_len=VQ_L * 3
|
z_seq_len=VQ_L
|
||||||
).to(device)
|
).to(device)
|
||||||
mg2 = GinkaMaskGIT(
|
mg2 = GinkaMaskGIT(
|
||||||
num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF,
|
num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF,
|
||||||
nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
||||||
z_seq_len=VQ_L * 3
|
z_seq_len=VQ_L
|
||||||
).to(device)
|
).to(device)
|
||||||
mg3 = GinkaMaskGIT(
|
mg3 = GinkaMaskGIT(
|
||||||
num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF,
|
num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF,
|
||||||
nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
||||||
z_seq_len=VQ_L * 3
|
z_seq_len=VQ_L
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
# 六个模型参数合并到同一优化器,端到端联合训练
|
# 六个模型参数合并到同一优化器,端到端联合训练
|
||||||
@ -166,18 +166,106 @@ def build_model(device: torch.device):
|
|||||||
# 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数
|
# 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
||||||
|
|
||||||
# 共用 VectorQuantizer:不参与梯度更新,仅在前向时做码本查表
|
# 三个独立 VectorQuantizer:各阶段使用自己的码本,避免语义空间相互干扰
|
||||||
quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
quantizer1 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||||
|
quantizer2 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||||
|
quantizer3 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||||
|
quantizers = (quantizer1, quantizer2, quantizer3)
|
||||||
|
|
||||||
return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler
|
return vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler
|
||||||
|
|
||||||
def focal_loss(logits, target):
|
def cross_entropy_loss(logits, target):
|
||||||
# logits: [B, L, C],需转为 [B, C, L] 以匹配 cross_entropy 期望格式
|
# logits: [B, L, C],需转为 [B, C, L] 以匹配 cross_entropy 期望格式
|
||||||
ce = F.cross_entropy(logits.permute(0, 2, 1), target, reduction='none')
|
return F.cross_entropy(logits.permute(0, 2, 1), target)
|
||||||
pt = torch.exp(-ce) # pt = 模型对正确类的预测概率
|
|
||||||
# Focal Loss:对高置信度样本降低权重,让模型更专注于难样本
|
def summarize_codebook_hits(code_hits: torch.Tensor) -> tuple[float, float, int]:
|
||||||
focal = ((1 - pt) ** FOCAL_GAMMA) * ce
|
total_hits = code_hits.sum()
|
||||||
return focal.mean()
|
if total_hits.item() <= 0:
|
||||||
|
return 0.0, 0.0, 0
|
||||||
|
|
||||||
|
probs = code_hits / total_hits
|
||||||
|
perplexity = torch.exp(
|
||||||
|
-(probs * torch.log(probs.clamp_min(1e-10))).sum()
|
||||||
|
).item()
|
||||||
|
active_codes = int((code_hits > 0).sum().item())
|
||||||
|
usage_rate = active_codes / code_hits.numel()
|
||||||
|
return perplexity, usage_rate, active_codes
|
||||||
|
|
||||||
|
def quantize_stage_latents(
|
||||||
|
quantizers: tuple[VectorQuantizer, VectorQuantizer, VectorQuantizer],
|
||||||
|
z_e1: torch.Tensor,
|
||||||
|
z_e2: torch.Tensor,
|
||||||
|
z_e3: torch.Tensor
|
||||||
|
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||||
|
quantizer1, quantizer2, quantizer3 = quantizers
|
||||||
|
z_q1, _, commit_loss1, _, code_hits1 = quantizer1(z_e1)
|
||||||
|
z_q2, _, commit_loss2, _, code_hits2 = quantizer2(z_e2)
|
||||||
|
z_q3, _, commit_loss3, _, code_hits3 = quantizer3(z_e3)
|
||||||
|
|
||||||
|
commit_loss = (commit_loss1 + commit_loss2 + commit_loss3) / 3
|
||||||
|
code_hits = torch.stack([code_hits1, code_hits2, code_hits3], dim=0)
|
||||||
|
return (z_q1, z_q2, z_q3), commit_loss, code_hits
|
||||||
|
|
||||||
|
def build_reference_rollout_steps(prob: float) -> int:
|
||||||
|
if random.random() >= prob:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return random.randint(1, 3)
|
||||||
|
|
||||||
|
def sample_reference_inputs(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
reference: torch.Tensor,
|
||||||
|
z_q: torch.Tensor,
|
||||||
|
struct: torch.Tensor,
|
||||||
|
target_density: torch.Tensor,
|
||||||
|
stage: int,
|
||||||
|
rollout_steps: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if rollout_steps <= 0:
|
||||||
|
return reference
|
||||||
|
|
||||||
|
sampled_reference = reference.clone()
|
||||||
|
with torch.no_grad():
|
||||||
|
current = sampled_reference.clone()
|
||||||
|
z_q_detached = z_q.detach()
|
||||||
|
|
||||||
|
for _ in range(rollout_steps):
|
||||||
|
masked_positions = current == MASK_TOKEN
|
||||||
|
masked_counts = masked_positions.sum(dim=1)
|
||||||
|
if int(masked_counts.sum().item()) <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
remain = compute_remaining(current, target_density, stage)
|
||||||
|
logits = model(current, z_q_detached, struct, remain)
|
||||||
|
probs = F.softmax(logits, dim=-1)
|
||||||
|
dist = torch.distributions.Categorical(probs)
|
||||||
|
predicted = dist.sample()
|
||||||
|
confidence = torch.gather(
|
||||||
|
probs,
|
||||||
|
-1,
|
||||||
|
predicted.unsqueeze(-1)
|
||||||
|
).squeeze(-1)
|
||||||
|
|
||||||
|
for local_idx in range(current.size(0)):
|
||||||
|
masked_count = int(masked_counts[local_idx].item())
|
||||||
|
if masked_count <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
masked_indices = masked_positions[local_idx].nonzero(as_tuple=True)[0]
|
||||||
|
reveal_count = max(1, math.ceil(masked_count * 0.1))
|
||||||
|
reveal_count = min(reveal_count, masked_indices.numel())
|
||||||
|
masked_confidence = confidence[local_idx, masked_indices]
|
||||||
|
_, reveal_order = torch.topk(
|
||||||
|
masked_confidence,
|
||||||
|
k=reveal_count,
|
||||||
|
largest=True
|
||||||
|
)
|
||||||
|
reveal_indices = masked_indices[reveal_order]
|
||||||
|
current[local_idx, reveal_indices] = predicted[local_idx, reveal_indices]
|
||||||
|
|
||||||
|
sampled_reference = current
|
||||||
|
|
||||||
|
return sampled_reference
|
||||||
|
|
||||||
def random_struct(device: torch.device) -> torch.Tensor:
|
def random_struct(device: torch.device) -> torch.Tensor:
|
||||||
# 随机采样一组结构参量,用于无条件自由生成
|
# 随机采样一组结构参量,用于无条件自由生成
|
||||||
@ -336,14 +424,17 @@ def full_generate_random_z(
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||||
|
quantizer1, quantizer2, quantizer3 = quantizers
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z = quantizer.sample(1, VQ_L, device)
|
z1 = quantizer1.sample(1, VQ_L, device)
|
||||||
|
z2 = quantizer2.sample(1, VQ_L, device)
|
||||||
|
z3 = quantizer3.sample(1, VQ_L, device)
|
||||||
|
|
||||||
# stage1:生成墙壁骨架
|
# stage1:生成墙壁骨架
|
||||||
pred1_np = maskgit_sample(
|
pred1_np = maskgit_sample(
|
||||||
mg1, input.clone(), z, struct, target_density, 1,
|
mg1, input.clone(), z1, struct, target_density, 1,
|
||||||
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
||||||
)
|
)
|
||||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||||
@ -351,7 +442,7 @@ def full_generate_random_z(
|
|||||||
|
|
||||||
# stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
|
# stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
|
||||||
pred2_np = maskgit_sample(
|
pred2_np = maskgit_sample(
|
||||||
mg2, inp2, z, struct, target_density, 2,
|
mg2, inp2, z2, struct, target_density, 2,
|
||||||
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||||
)
|
)
|
||||||
merged12 = pred1_np.copy()
|
merged12 = pred1_np.copy()
|
||||||
@ -361,7 +452,7 @@ def full_generate_random_z(
|
|||||||
|
|
||||||
# stage3:填充 resource(3)
|
# stage3:填充 resource(3)
|
||||||
pred3_np = maskgit_sample(
|
pred3_np = maskgit_sample(
|
||||||
mg3, inp3, z, struct, target_density, 3,
|
mg3, inp3, z3, struct, target_density, 3,
|
||||||
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||||
)
|
)
|
||||||
merged123 = merged12.copy()
|
merged123 = merged12.copy()
|
||||||
@ -371,26 +462,28 @@ def full_generate_random_z(
|
|||||||
|
|
||||||
def full_generate_specific_z(
|
def full_generate_specific_z(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
z: torch.Tensor,
|
z_q: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||||
struct: torch.Tensor,
|
struct: torch.Tensor,
|
||||||
target_density: torch.Tensor,
|
target_density: torch.Tensor,
|
||||||
models: list[torch.nn.Module],
|
models: list[torch.nn.Module],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
z1, z2, z3 = z_q
|
||||||
|
|
||||||
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
|
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
|
||||||
pred1_np = maskgit_sample(
|
pred1_np = maskgit_sample(
|
||||||
mg1, input.clone(), z, struct, target_density, 1,
|
mg1, input.clone(), z1, struct, target_density, 1,
|
||||||
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
|
||||||
)
|
)
|
||||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||||
inp2[inp2 == 0] = MASK_TOKEN
|
inp2[inp2 == 0] = MASK_TOKEN
|
||||||
|
|
||||||
pred2_np = maskgit_sample(
|
pred2_np = maskgit_sample(
|
||||||
mg2, inp2, z, struct, target_density, 2,
|
mg2, inp2, z2, struct, target_density, 2,
|
||||||
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||||
)
|
)
|
||||||
merged12 = pred1_np.copy()
|
merged12 = pred1_np.copy()
|
||||||
@ -399,7 +492,7 @@ def full_generate_specific_z(
|
|||||||
inp3[inp3 == 0] = MASK_TOKEN
|
inp3[inp3 == 0] = MASK_TOKEN
|
||||||
|
|
||||||
pred3_np = maskgit_sample(
|
pred3_np = maskgit_sample(
|
||||||
mg3, inp3, z, struct, target_density, 3,
|
mg3, inp3, z3, struct, target_density, 3,
|
||||||
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||||
)
|
)
|
||||||
merged123 = merged12.copy()
|
merged123 = merged12.copy()
|
||||||
@ -493,9 +586,10 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
|||||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||||
struct_t = batch["struct_inject"][0:1].to(device)
|
struct_t = batch["struct_inject"][0:1].to(device)
|
||||||
target_density_t = batch["target_density"][0:1].to(device)
|
target_density_t = batch["target_density"][0:1].to(device)
|
||||||
|
z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1])
|
||||||
kf = rand_keep()
|
kf = rand_keep()
|
||||||
auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
|
auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
|
||||||
inp1_t, z_q[0:1], struct_t, target_density_t, models, device, keep_fixed=kf
|
inp1_t, z_q_single, struct_t, target_density_t, models, device, keep_fixed=kf
|
||||||
)
|
)
|
||||||
kf_label = 'fix' if kf[0] else 'free'
|
kf_label = 'fix' if kf[0] else 'free'
|
||||||
|
|
||||||
@ -665,6 +759,7 @@ def visualize_density_var(batch, z_q, models, device, tile_dict):
|
|||||||
struct_t = batch["struct_inject"][0:1].to(device)
|
struct_t = batch["struct_inject"][0:1].to(device)
|
||||||
struct_cpu = batch["struct_inject"][0]
|
struct_cpu = batch["struct_inject"][0]
|
||||||
base_target_density = batch["target_density"][0:1].to(device)
|
base_target_density = batch["target_density"][0:1].to(device)
|
||||||
|
z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1])
|
||||||
ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
||||||
gen_imgs = []
|
gen_imgs = []
|
||||||
wall_count_values = [20, 35, 50, 65, 80]
|
wall_count_values = [20, 35, 50, 65, 80]
|
||||||
@ -673,7 +768,7 @@ def visualize_density_var(batch, z_q, models, device, tile_dict):
|
|||||||
fixed_target_density[0, WALL_DENSITY_IDX] = wall_count / MAP_SIZE
|
fixed_target_density[0, WALL_DENSITY_IDX] = wall_count / MAP_SIZE
|
||||||
target_density_cpu = fixed_target_density[0].cpu()
|
target_density_cpu = fixed_target_density[0].cpu()
|
||||||
_, _, merged123 = full_generate_specific_z(
|
_, _, merged123 = full_generate_specific_z(
|
||||||
inp1_t, z_q[0:1], struct_t, fixed_target_density, models, device
|
inp1_t, z_q_single, struct_t, fixed_target_density, models, device
|
||||||
)
|
)
|
||||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
|
||||||
row1 = [to_img(ref_np)] + gen_imgs[:2]
|
row1 = [to_img(ref_np)] + gen_imgs[:2]
|
||||||
@ -695,7 +790,8 @@ def validate(
|
|||||||
density_stats: dict,
|
density_stats: dict,
|
||||||
epoch: int
|
epoch: int
|
||||||
):
|
):
|
||||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||||
|
quantizer1, quantizer2, quantizer3 = quantizers
|
||||||
|
|
||||||
# 切换为推理模式(关闭 Dropout / BatchNorm 统计更新)
|
# 切换为推理模式(关闭 Dropout / BatchNorm 统计更新)
|
||||||
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
||||||
@ -706,6 +802,7 @@ def validate(
|
|||||||
loss2_total = torch.Tensor([0]).to(device)
|
loss2_total = torch.Tensor([0]).to(device)
|
||||||
loss3_total = torch.Tensor([0]).to(device)
|
loss3_total = torch.Tensor([0]).to(device)
|
||||||
commit_total = torch.Tensor([0]).to(device)
|
commit_total = torch.Tensor([0]).to(device)
|
||||||
|
code_hits_total = torch.zeros(3, quantizer1.K, device=device)
|
||||||
|
|
||||||
density_metrics = {
|
density_metrics = {
|
||||||
1: {"mae": 0.0, "over": 0.0, "count": 0},
|
1: {"mae": 0.0, "over": 0.0, "count": 0},
|
||||||
@ -736,27 +833,30 @@ def validate(
|
|||||||
struct = batch["struct_inject"].to(device)
|
struct = batch["struct_inject"].to(device)
|
||||||
target_density = batch["target_density"].to(device)
|
target_density = batch["target_density"].to(device)
|
||||||
|
|
||||||
# VQ 编码:各阶段独立编码后拼接、量化
|
# VQ 编码:各阶段独立编码并分别量化
|
||||||
z_e1 = vq1(enc1) # [B, L, d_z]
|
z_e1 = vq1(enc1) # [B, L, d_z]
|
||||||
z_e2 = vq2(enc2)
|
z_e2 = vq2(enc2)
|
||||||
z_e3 = vq3(enc3)
|
z_e3 = vq3(enc3)
|
||||||
|
|
||||||
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
|
z_q, commit_loss, code_hits = quantize_stage_latents(
|
||||||
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
|
quantizers, z_e1, z_e2, z_e3
|
||||||
|
)
|
||||||
|
z_q1, z_q2, z_q3 = z_q
|
||||||
|
|
||||||
remain1 = compute_remaining(inp1, target_density, 1)
|
remain1 = compute_remaining(inp1, target_density, 1)
|
||||||
remain2 = compute_remaining(inp2, target_density, 2)
|
remain2 = compute_remaining(inp2, target_density, 2)
|
||||||
remain3 = compute_remaining(inp3, target_density, 3)
|
remain3 = compute_remaining(inp3, target_density, 3)
|
||||||
|
|
||||||
# 三阶段 MaskGIT 推理(均以完整 z_q、struct 和动态 remain 为条件)
|
# 三阶段 MaskGIT 推理:各阶段只接收自己的 z_q
|
||||||
logits1 = mg1(inp1, z_q, struct, remain1)
|
logits1 = mg1(inp1, z_q1, struct, remain1)
|
||||||
logits2 = mg2(inp2, z_q, struct, remain2)
|
logits2 = mg2(inp2, z_q2, struct, remain2)
|
||||||
logits3 = mg3(inp3, z_q, struct, remain3)
|
logits3 = mg3(inp3, z_q3, struct, remain3)
|
||||||
|
|
||||||
loss1_total += focal_loss(logits1, target1)
|
loss1_total += cross_entropy_loss(logits1, target1)
|
||||||
loss2_total += focal_loss(logits2, target2)
|
loss2_total += cross_entropy_loss(logits2, target2)
|
||||||
loss3_total += focal_loss(logits3, target3)
|
loss3_total += cross_entropy_loss(logits3, target3)
|
||||||
commit_total += commit_loss
|
commit_total += commit_loss
|
||||||
|
code_hits_total += code_hits
|
||||||
|
|
||||||
# 计算各目标对象的真实密度误差与过量生成密度
|
# 计算各目标对象的真实密度误差与过量生成密度
|
||||||
pred1_map = torch.argmax(logits1, dim=-1).cpu()
|
pred1_map = torch.argmax(logits1, dim=-1).cpu()
|
||||||
@ -806,19 +906,20 @@ def validate(
|
|||||||
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
||||||
m.train()
|
m.train()
|
||||||
|
|
||||||
return loss1_total, loss2_total, loss3_total, commit_total
|
return loss1_total, loss2_total, loss3_total, commit_total, code_hits_total
|
||||||
|
|
||||||
def train(device: torch.device):
|
def train(device: torch.device):
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
models = build_model(device)
|
models = build_model(device)
|
||||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
|
||||||
|
quantizer1, quantizer2, quantizer3 = quantizers
|
||||||
|
|
||||||
tqdm.write(f"Device: {device}")
|
tqdm.write(f"Device: {device}")
|
||||||
model_list = [
|
model_list = [
|
||||||
("vq1", vq1), ("vq2", vq2), ("vq3", vq3),
|
("vq1", vq1), ("vq2", vq2), ("vq3", vq3),
|
||||||
("mg1", mg1), ("mg2", mg2), ("mg3", mg3),
|
("mg1", mg1), ("mg2", mg2), ("mg3", mg3),
|
||||||
("quantizer", quantizer)
|
("quantizer1", quantizer1), ("quantizer2", quantizer2), ("quantizer3", quantizer3)
|
||||||
]
|
]
|
||||||
total_params = 0
|
total_params = 0
|
||||||
for name, m in model_list:
|
for name, m in model_list:
|
||||||
@ -838,7 +939,15 @@ def train(device: torch.device):
|
|||||||
mg1.load_state_dict(ckpt["mg1"])
|
mg1.load_state_dict(ckpt["mg1"])
|
||||||
mg2.load_state_dict(ckpt["mg2"])
|
mg2.load_state_dict(ckpt["mg2"])
|
||||||
mg3.load_state_dict(ckpt["mg3"])
|
mg3.load_state_dict(ckpt["mg3"])
|
||||||
quantizer.load_state_dict(ckpt["quantizer"])
|
if "quantizer1" in ckpt:
|
||||||
|
quantizer1.load_state_dict(ckpt["quantizer1"])
|
||||||
|
quantizer2.load_state_dict(ckpt["quantizer2"])
|
||||||
|
quantizer3.load_state_dict(ckpt["quantizer3"])
|
||||||
|
elif "quantizer" in ckpt:
|
||||||
|
quantizer1.load_state_dict(ckpt["quantizer"])
|
||||||
|
quantizer2.load_state_dict(ckpt["quantizer"])
|
||||||
|
quantizer3.load_state_dict(ckpt["quantizer"])
|
||||||
|
tqdm.write("Loaded legacy shared quantizer weights into quantizer1/2/3")
|
||||||
# load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练)
|
# load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练)
|
||||||
if args.load_optim and "optimizer" in ckpt:
|
if args.load_optim and "optimizer" in ckpt:
|
||||||
optimizer.load_state_dict(ckpt["optimizer"])
|
optimizer.load_state_dict(ckpt["optimizer"])
|
||||||
@ -878,6 +987,7 @@ def train(device: torch.device):
|
|||||||
loss2_total = torch.Tensor([0]).to(device)
|
loss2_total = torch.Tensor([0]).to(device)
|
||||||
loss3_total = torch.Tensor([0]).to(device)
|
loss3_total = torch.Tensor([0]).to(device)
|
||||||
commit_total = torch.Tensor([0]).to(device)
|
commit_total = torch.Tensor([0]).to(device)
|
||||||
|
code_hits_total = torch.zeros(3, quantizer1.K, device=device)
|
||||||
|
|
||||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||||
# 三阶段各自的掩码输入序列、预测目标和编码器上下文
|
# 三阶段各自的掩码输入序列、预测目标和编码器上下文
|
||||||
@ -904,26 +1014,39 @@ def train(device: torch.device):
|
|||||||
z_e2 = vq2(enc2)
|
z_e2 = vq2(enc2)
|
||||||
z_e3 = vq3(enc3)
|
z_e3 = vq3(enc3)
|
||||||
|
|
||||||
# 合并三阶段编码后量化
|
# 三阶段分别量化,各自使用独立 codebook
|
||||||
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
|
z_q, commit_loss, code_hits = quantize_stage_latents(
|
||||||
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
|
quantizers, z_e1, z_e2, z_e3
|
||||||
|
)
|
||||||
|
z_q1, z_q2, z_q3 = z_q
|
||||||
|
|
||||||
|
rollout_steps = build_reference_rollout_steps(REFERENCE_SAMPLE_PROB)
|
||||||
|
inp1 = sample_reference_inputs(
|
||||||
|
mg1, inp1, z_q1, struct, target_density, 1, rollout_steps
|
||||||
|
)
|
||||||
|
inp2 = sample_reference_inputs(
|
||||||
|
mg2, inp2, z_q2, struct, target_density, 2, rollout_steps
|
||||||
|
)
|
||||||
|
inp3 = sample_reference_inputs(
|
||||||
|
mg3, inp3, z_q3, struct, target_density, 3, rollout_steps
|
||||||
|
)
|
||||||
|
|
||||||
remain1 = compute_remaining(inp1, target_density, 1)
|
remain1 = compute_remaining(inp1, target_density, 1)
|
||||||
remain2 = compute_remaining(inp2, target_density, 2)
|
remain2 = compute_remaining(inp2, target_density, 2)
|
||||||
remain3 = compute_remaining(inp3, target_density, 3)
|
remain3 = compute_remaining(inp3, target_density, 3)
|
||||||
|
|
||||||
# 三阶段 MaskGIT 前向(均接收完整三阶段 z_q、struct 和动态 remain 条件)
|
# 三阶段 MaskGIT 前向:各阶段只接收自己的 z_q、struct 和动态 remain 条件
|
||||||
logits1 = mg1(inp1, z_q, struct, remain1)
|
logits1 = mg1(inp1, z_q1, struct, remain1)
|
||||||
logits2 = mg2(inp2, z_q, struct, remain2)
|
logits2 = mg2(inp2, z_q2, struct, remain2)
|
||||||
logits3 = mg3(inp3, z_q, struct, remain3)
|
logits3 = mg3(inp3, z_q3, struct, remain3)
|
||||||
|
|
||||||
# 三阶段 Focal Loss + VQ commit loss 加权求和
|
# 三阶段 Cross Entropy + VQ commit loss 加权求和
|
||||||
loss1 = focal_loss(logits1, target1)
|
loss1 = cross_entropy_loss(logits1, target1)
|
||||||
loss2 = focal_loss(logits2, target2)
|
loss2 = cross_entropy_loss(logits2, target2)
|
||||||
loss3 = focal_loss(logits3, target3)
|
loss3 = cross_entropy_loss(logits3, target3)
|
||||||
loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1
|
loss1_weighted = STAGE1_CE_WEIGHT * loss1
|
||||||
loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2
|
loss2_weighted = STAGE2_CE_WEIGHT * loss2
|
||||||
loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3
|
loss3_weighted = STAGE3_CE_WEIGHT * loss3
|
||||||
commit_weighted = VQ_BETA * commit_loss
|
commit_weighted = VQ_BETA * commit_loss
|
||||||
loss = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
|
loss = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
|
||||||
|
|
||||||
@ -936,11 +1059,13 @@ def train(device: torch.device):
|
|||||||
loss2_total += loss2.detach()
|
loss2_total += loss2.detach()
|
||||||
loss3_total += loss3.detach()
|
loss3_total += loss3.detach()
|
||||||
commit_total += commit_loss.detach()
|
commit_total += commit_loss.detach()
|
||||||
|
code_hits_total += code_hits.detach()
|
||||||
|
|
||||||
# 每个 epoch 结束后更新学习率
|
# 每个 epoch 结束后更新学习率
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
data_length = len(dataloader)
|
data_length = len(dataloader)
|
||||||
|
train_perplexity, train_usage_rate, train_active_codes = summarize_codebook_hits(code_hits_total)
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||||
f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
|
f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
|
||||||
@ -948,20 +1073,23 @@ def train(device: torch.device):
|
|||||||
f"L2: {loss2_total.item() / data_length:.6f} | "
|
f"L2: {loss2_total.item() / data_length:.6f} | "
|
||||||
f"L3: {loss3_total.item() / data_length:.6f} | "
|
f"L3: {loss3_total.item() / data_length:.6f} | "
|
||||||
f"VQ: {commit_total.item() / data_length:.6f} | "
|
f"VQ: {commit_total.item() / data_length:.6f} | "
|
||||||
|
f"PPL: {train_perplexity:.4f} | "
|
||||||
|
f"Usage: {train_usage_rate:.4f} ({train_active_codes}/{code_hits_total.numel()}) | "
|
||||||
f"LR: {scheduler.get_last_lr()[0]:.6f}"
|
f"LR: {scheduler.get_last_lr()[0]:.6f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存
|
# 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存
|
||||||
if (epoch + 1) % CHECKPOINT == 0:
|
if (epoch + 1) % CHECKPOINT == 0:
|
||||||
losses = validate(dataloader_val, models, device, tile_dict, dataset.density_stats, epoch + 1)
|
losses = validate(dataloader_val, models, device, tile_dict, dataset.density_stats, epoch + 1)
|
||||||
loss1_total, loss2_total, loss3_total, commit_total = losses
|
loss1_total, loss2_total, loss3_total, commit_total, code_hits_total = losses
|
||||||
loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1_total
|
loss1_weighted = STAGE1_CE_WEIGHT * loss1_total
|
||||||
loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2_total
|
loss2_weighted = STAGE2_CE_WEIGHT * loss2_total
|
||||||
loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3_total
|
loss3_weighted = STAGE3_CE_WEIGHT * loss3_total
|
||||||
commit_weighted = VQ_BETA * commit_total
|
commit_weighted = VQ_BETA * commit_total
|
||||||
loss_total = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
|
loss_total = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
|
||||||
|
|
||||||
data_length = len(dataloader_val)
|
data_length = len(dataloader_val)
|
||||||
|
val_perplexity, val_usage_rate, val_active_codes = summarize_codebook_hits(code_hits_total)
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||||
f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
|
f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
|
||||||
@ -969,6 +1097,8 @@ def train(device: torch.device):
|
|||||||
f"L2: {loss2_total.item() / data_length:.6f} | "
|
f"L2: {loss2_total.item() / data_length:.6f} | "
|
||||||
f"L3: {loss3_total.item() / data_length:.6f} | "
|
f"L3: {loss3_total.item() / data_length:.6f} | "
|
||||||
f"VQ: {commit_total.item() / data_length:.6f} | "
|
f"VQ: {commit_total.item() / data_length:.6f} | "
|
||||||
|
f"PPL: {val_perplexity:.4f} | "
|
||||||
|
f"Usage: {val_usage_rate:.4f} ({val_active_codes}/{code_hits_total.numel()}) | "
|
||||||
)
|
)
|
||||||
|
|
||||||
ckpt_path = f"result/seperated/sep-{epoch + 1}.pth"
|
ckpt_path = f"result/seperated/sep-{epoch + 1}.pth"
|
||||||
@ -980,7 +1110,9 @@ def train(device: torch.device):
|
|||||||
"mg1": mg1.state_dict(),
|
"mg1": mg1.state_dict(),
|
||||||
"mg2": mg2.state_dict(),
|
"mg2": mg2.state_dict(),
|
||||||
"mg3": mg3.state_dict(),
|
"mg3": mg3.state_dict(),
|
||||||
"quantizer": quantizer.state_dict(),
|
"quantizer1": quantizer1.state_dict(),
|
||||||
|
"quantizer2": quantizer2.state_dict(),
|
||||||
|
"quantizer3": quantizer3.state_dict(),
|
||||||
"optimizer": optimizer.state_dict(),
|
"optimizer": optimizer.state_dict(),
|
||||||
"scheduler": scheduler.state_dict(),
|
"scheduler": scheduler.state_dict(),
|
||||||
}, ckpt_path)
|
}, ckpt_path)
|
||||||
@ -996,7 +1128,9 @@ def train(device: torch.device):
|
|||||||
"mg1": mg1.state_dict(),
|
"mg1": mg1.state_dict(),
|
||||||
"mg2": mg2.state_dict(),
|
"mg2": mg2.state_dict(),
|
||||||
"mg3": mg3.state_dict(),
|
"mg3": mg3.state_dict(),
|
||||||
"quantizer": quantizer.state_dict(),
|
"quantizer1": quantizer1.state_dict(),
|
||||||
|
"quantizer2": quantizer2.state_dict(),
|
||||||
|
"quantizer3": quantizer3.state_dict(),
|
||||||
"optimizer": optimizer.state_dict(),
|
"optimizer": optimizer.state_dict(),
|
||||||
"scheduler": scheduler.state_dict(),
|
"scheduler": scheduler.state_dict(),
|
||||||
}, final_path)
|
}, final_path)
|
||||||
|
|||||||
@ -18,8 +18,23 @@ class VectorQuantizer(nn.Module):
|
|||||||
self.codebook = nn.Embedding(K, d_z)
|
self.codebook = nn.Embedding(K, d_z)
|
||||||
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
|
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
|
||||||
|
|
||||||
def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
def codebook_stats(
|
||||||
# z_e: [B, L * 3, d_z]
|
self, indices: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
flat_indices = indices.reshape(-1)
|
||||||
|
one_hot = F.one_hot(flat_indices, num_classes=self.K).float()
|
||||||
|
avg_probs = one_hot.mean(dim=0)
|
||||||
|
perplexity = torch.exp(
|
||||||
|
-(avg_probs * torch.log(avg_probs.clamp_min(1e-10))).sum()
|
||||||
|
)
|
||||||
|
usage_rate = (avg_probs > 0).float().mean()
|
||||||
|
usage_count = one_hot.sum(dim=0)
|
||||||
|
return perplexity, usage_rate, usage_count
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, z_e: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
# z_e: [B, L, d_z]
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
z_e: [B, L, d_z] 编码器输出的连续向量序列
|
z_e: [B, L, d_z] 编码器输出的连续向量序列
|
||||||
@ -31,7 +46,7 @@ class VectorQuantizer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
B, L, d_z = z_e.shape
|
B, L, d_z = z_e.shape
|
||||||
|
|
||||||
z_flat = z_e.reshape(B * L, d_z) # [B * L * 3, d_z]
|
z_flat = z_e.reshape(B * L, d_z) # [B * L, d_z]
|
||||||
|
|
||||||
codebook_w = self.codebook.weight # [K, d_z]
|
codebook_w = self.codebook.weight # [K, d_z]
|
||||||
|
|
||||||
@ -56,13 +71,9 @@ class VectorQuantizer(nn.Module):
|
|||||||
commit_loss = F.mse_loss(z_e, z_q.detach())
|
commit_loss = F.mse_loss(z_e, z_q.detach())
|
||||||
|
|
||||||
indices = indices.reshape(B, L)
|
indices = indices.reshape(B, L)
|
||||||
return z_q_st, indices, commit_loss
|
perplexity, usage_rate, usage_count = self.codebook_stats(indices)
|
||||||
|
return z_q_st, indices, commit_loss, perplexity, usage_count
|
||||||
|
|
||||||
def sample(self, B: int, L: int, device: torch.device) -> torch.Tensor:
|
def sample(self, B: int, L: int, device: torch.device) -> torch.Tensor:
|
||||||
indices1 = torch.randint(0, self.K, (B, L), device=device)
|
indices = torch.randint(0, self.K, (B, L), device=device)
|
||||||
indices2 = torch.randint(0, self.K, (B, L), device=device)
|
return self.codebook(indices)
|
||||||
indices3 = torch.randint(0, self.K, (B, L), device=device)
|
|
||||||
z1 = self.codebook(indices1)
|
|
||||||
z2 = self.codebook(indices2)
|
|
||||||
z3 = self.codebook(indices3)
|
|
||||||
return torch.cat([z1, z2, z3], dim=1)
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user