ginka-generator/ginka/maskGIT/model.py
unanmed 3676958781 feat: 分三阶段训练
Co-authored-by: Copilot <copilot@github.com>
2026-05-07 20:59:22 +08:00

200 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
import torch
import torch.nn as nn
from ..utils import print_memory
from .maskGIT import Transformer
# Vocabulary size of each structure label. The last index of every vocabulary
# (size - 1) is reserved as the unconditional "null" placeholder.
SYM_VOCAB = 8 # symmetry H/V/C three-bit combination: labels 0-6, 7 = null
ROOM_VOCAB = 4 # roomCountLevel: labels 0-2, 3 = null
BRANCH_VOCAB = 4 # branchLevel: labels 0-2, 3 = null
OUTER_VOCAB = 3 # outerWall: labels 0-1, 2 = null
class GinkaMaskGIT(nn.Module):
    """
    MaskGIT-style map generator conditioned on VQ-VAE latents and structure labels.

    The model receives a masked map token sequence together with a latent
    sequence ``z`` and predicts a tile class for every map position. ``z`` and
    the four embedded structure labels are projected to ``d_model`` by two
    separate projection heads, concatenated along the sequence axis and handed
    to the project ``Transformer`` through its ``memory`` argument.

    Per the original author's notes, the Transformer is an encoder-decoder
    whose decoder cross-attends to this memory, and ``z`` is intended as a
    style/diversity control signal rather than as structural reconstruction
    guidance.
    """

    def __init__(
        self,
        num_classes: int = 16,
        d_model: int = 192,
        d_z: int = 64,
        dim_ff: int = 512,
        nhead: int = 8,
        num_layers: int = 4,
        map_size: int = 13 * 13,
        z_dropout: float = 0.1,
        struct_dropout: float = 0.15,
    ):
        """
        Args:
            num_classes: number of tile classes (documented by the author as
                including the MASK token, id 15).
            d_model: internal Transformer width.
            d_z: embedding width of the VQ-VAE codewords; expected to match
                ``GinkaVQVAE.d_z`` (per the original notes).
            dim_ff: hidden width of the feed-forward sub-layers.
            nhead: number of attention heads.
            num_layers: number of Transformer layers.
            map_size: number of map tokens (H * W); fixes the length of the
                learned positional embedding.
            z_dropout: training-time probability, per sample, of replacing the
                whole ``z`` sequence with Gaussian noise.
            struct_dropout: training-time probability, per sample and per
                label, of replacing a structure label with its null index
                (classifier-free-guidance style training).
        """
        super().__init__()

        self.z_dropout = z_dropout
        self.struct_dropout_prob = struct_dropout

        # Tile embedding plus a learned positional embedding.
        self.tile_embedding = nn.Embedding(num_classes, d_model)
        self.pos_embedding = nn.Parameter(
            torch.randn(1, map_size, d_model) * 0.02
        )

        # Projection that lifts VQ codewords from d_z to d_model.
        self.z_proj = self._build_projection(d_z, d_model)

        # Structure labels are embedded into d_z and use their own projection
        # head, because their semantics differ from those of the VQ codewords.
        self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
        self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
        self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
        self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
        self.struct_proj = self._build_projection(d_z, d_model)

        # Project-local Transformer; it consumes the map tokens and a memory
        # sequence (see forward).
        self.transformer = Transformer(
            d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
        )
        self.output_fc = nn.Linear(d_model, num_classes)

    @staticmethod
    def _build_projection(d_in: int, d_model: int) -> nn.Sequential:
        """Build the Linear -> LayerNorm -> GELU -> Linear -> LayerNorm head."""
        hidden = d_model * 2
        return nn.Sequential(
            nn.Linear(d_in, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.LayerNorm(d_model),
        )

    def forward(
        self,
        map: torch.Tensor,
        z: torch.Tensor,
        struct_cond: torch.Tensor | None = None,
        dropout_struct: bool = False,
    ) -> torch.Tensor:
        """
        Predict tile logits for every map position.

        Args:
            map: [B, H*W] masked map token sequence (MASK token = 15 per the
                original notes).
            z: [B, L_total, d_z] quantized VQ-VAE latents; the original notes
                state that in "scheme B" L_total = L1 + L2 + L3 (three z
                streams concatenated).
            struct_cond: [B, 4] long tensor of structure labels ordered as
                [cond_sym, cond_room, cond_branch, cond_outer]. ``None`` is
                treated as all-null (unconditional mode).
            dropout_struct: when True, every structure label is forced to its
                null index regardless of ``struct_cond``.

        Returns:
            logits: [B, H*W, num_classes]
        """
        batch = z.shape[0]
        dev = z.device

        # z dropout (training only): with probability z_dropout per sample the
        # entire z sequence is swapped for standard-normal noise.
        if self.training and self.z_dropout > 0:
            replace = torch.rand(batch, 1, 1, device=dev) < self.z_dropout
            noise = torch.randn_like(z)
            z = torch.where(replace, noise, z)

        # Null index of each label vocabulary, in [sym, room, branch, outer] order.
        null_ids = (
            SYM_VOCAB - 1,
            ROOM_VOCAB - 1,
            BRANCH_VOCAB - 1,
            OUTER_VOCAB - 1,
        )

        if struct_cond is None or dropout_struct:
            # Unconditional mode: every label is the null placeholder.
            labels = [
                torch.full((batch,), null_id, dtype=torch.long, device=dev)
                for null_id in null_ids
            ]
        else:
            cond = struct_cond.to(dev)
            labels = [cond[:, col] for col in range(4)]
            # Training only: drop each label to null independently.
            if self.training and self.struct_dropout_prob > 0:
                kept = []
                for ids, null_id in zip(labels, null_ids):
                    hit = torch.rand(batch, device=dev) < self.struct_dropout_prob
                    kept.append(
                        torch.where(hit, torch.full_like(ids, null_id), ids)
                    )
                labels = kept

        # Embed each label into d_z and line them up as a 4-token sequence.
        embedders = (
            self.sym_embed,
            self.room_embed,
            self.branch_embed,
            self.outer_embed,
        )
        struct_seq = torch.stack(
            [embed(ids) for embed, ids in zip(embedders, labels)], dim=1
        )  # [B, 4, d_z]

        # Separate projection heads, then concatenate: VQ tokens first,
        # structure tokens appended at the end.
        vq_memory = self.z_proj(z)  # [B, L, d_model]
        struct_memory = self.struct_proj(struct_seq)  # [B, 4, d_model]
        memory = torch.cat([vq_memory, struct_memory], dim=1)  # [B, L+4, d_model]

        # Tile embedding plus positional embedding.
        tokens = self.tile_embedding(map) + self.pos_embedding  # [B, H*W, d_model]

        hidden = self.transformer(tokens, memory=memory)  # [B, H*W, d_model]
        return self.output_fc(hidden)  # [B, H*W, num_classes]
if __name__ == "__main__":
    # Smoke test: build the model on CPU, report parameter counts, and run one
    # conditional and one unconditional forward pass in training mode.
    device = torch.device("cpu")
    batch_size = 4
    seq_len = 13 * 13

    model_kwargs = dict(
        num_classes=16,
        d_model=192,
        d_z=64,
        dim_ff=512,
        nhead=8,
        num_layers=4,
        map_size=seq_len,
    )
    model = GinkaMaskGIT(**model_kwargs).to(device)

    # Overall and per-child parameter counts.
    total_params = sum(param.numel() for param in model.parameters())
    print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)")
    for name, module in model.named_children():
        n = sum(param.numel() for param in module.parameters())
        print(f" {name}: {n:,}")

    # Random inputs: map tokens [B=4, 169], latents [B=4, L=2, d_z=64].
    map_input = torch.randint(0, 16, (batch_size, seq_len)).to(device)
    z_input = torch.randn(batch_size, 2, 64).to(device)
    # Structure labels [B=4, 4]; the third row uses the null index of every label.
    struct_rows = [
        [3, 1, 0, 1],
        [0, 2, 1, 0],
        [7, 3, 3, 2],
        [1, 0, 2, 1],
    ]
    struct_input = torch.tensor(struct_rows, dtype=torch.long).to(device)

    model.train()

    # Conditional forward pass.
    logits = model(map_input, z_input, struct_cond=struct_input)
    print(f"\nlogits shape: {logits.shape}")  # [4, 169, 16]

    # Unconditional forward pass (no structure labels supplied).
    logits_uncond = model(map_input, z_input, struct_cond=None)
    print(f"logits_uncond shape: {logits_uncond.shape}")  # [4, 169, 16]

    print_memory(device, "前向传播后")