From bf3d24e680b46932ffb45fe720d8a0a04bcdb690 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Fri, 15 May 2026 18:15:50 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20AdaLN=20=E6=9D=A1=E4=BB=B6=E6=B3=A8?= =?UTF-8?q?=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/film-adaln-cond-design.md | 212 +++++++++++++++++++++++++++++++++ ginka/maskGIT/maskGIT.py | 76 ++++++++---- ginka/maskGIT/model.py | 119 ++++++++---------- ginka/train_seperated.py | 11 +- 4 files changed, 321 insertions(+), 97 deletions(-) create mode 100644 docs/film-adaln-cond-design.md diff --git a/docs/film-adaln-cond-design.md b/docs/film-adaln-cond-design.md new file mode 100644 index 0000000..7595076 --- /dev/null +++ b/docs/film-adaln-cond-design.md @@ -0,0 +1,212 @@ +# 条件注入方式改进:从 Cross Attention 到 FiLM / AdaLN + +## 问题背景 + +### 当前条件注入方式 + +`GinkaMaskGIT` 当前使用的条件注入策略如下: + +1. VQ 码字 `z`(形状 `[B, L*3, d_z]`)通过 `z_proj` 投影到 `d_model` 维度 +2. 结构标签(`sym / room / branch / outer`)各自嵌入后拼接为 `[B, 4, d_model]` +3. 密度标签(`door / monster / resource`)三个嵌入相加后经 MLP 得到 `[B, 1, d_model]` +4. 上述三部分拼接为 `memory`(`[B, L*3+5, d_model]`),作为 cross-attention 的 key/value +5. Transformer decoder 以 map token 作为 query,对 `memory` 做 cross-attention + +### 问题分析 + +Cross-attention 的本质是**查询驱动**(query-driven)的检索机制:模型只在需要时才主动去 `memory` 中寻找相关信息,且注意力权重由 query(地图 token)与 key 的相似度决定。 + +这一机制对**空间局部条件**(如参考图像特征、空间先验)效果良好,但对**全局标量条件**(如"资源密度为 High")存在以下问题: + +#### 1. 隐式性:条件无法强制生效 + +模型可以选择性地"忽视"某个 memory 条目。结构/密度条件只是 memory 序列中的几个 token,与 VQ 码字并列竞争注意力权重。当 VQ 码字已经携带了足够多的生成信息时,模型倾向于将注意力集中在 VQ 码字上,而对结构/密度 token 的注意力权重趋近于零。 + +实验现象印证了这一点:即使将密度标签设置为 High,模型生成的怪物/资源数量与 Low 时差异极小,说明密度条件被模型基本忽略。 + +#### 2. 语义不匹配:全局信号与局部查询不对齐 + +Cross-attention 的设计假设 key/value 携带**空间位置相关**的信息(例如编码器输出的特征图),query 在不同位置关注不同的 key。然而: + +- 密度标签是一个全局标量(表示整张地图的资源密度档位),没有空间维度 +- 所有地图位置(169 个 token)的 query 若都要接收该全局信号,需要所有 query 一致地高度关注同一个 key,这与 cross-attention 的设计初衷相悖 + +#### 3. 与 VQ 码字竞争导致梯度稀释 + +结构/密度条件作为 memory token,与 VQ 码字通过同一个 softmax 竞争注意力。当 VQ 码字数量远多于条件 token(当前 L\*3=6 对 5),且 VQ 码字携带了更多"有用信息"时,梯度信号倾向于强化对 VQ 的关注,条件 token 的参数得不到有效更新。 + +#### 4. VQ 码字 z 本身也未被充分利用 + +即使将结构/密度从 cross-attention 中移出,VQ 码字 `z` 本身也存在相同的问题。训练前期观察到模型倾向于输出高度相似的地图(风格单一、多样性极低),这表明模型并未有效利用随机采样的 `z`。根本原因相同:cross-attention 是 query-driven 的,模型可以在不关注 `z` 的情况下仅靠地图 token 自注意力完成预测,`z` 的梯度信号因此极为稀弱。因此,`z` 同样需要改为全局 AdaLN 注入,而非仅依赖 cross-attention。 + +--- + +## 改进方案 + +### 核心思路 + +全局条件(结构标签、密度标签)应当作用于 **每一层的特征变换**,以加法偏移或缩放仿射的形式强制施加到所有 map token 上,使模型**无法绕过**该条件。这正是 FiLM 和 AdaLN 的设计目标。 + +### FiLM(Feature-wise Linear Modulation) + +FiLM 对特征向量做逐元素仿射变换: + +$$ +\text{FiLM}(x, c) = \gamma(c) \odot x + \beta(c) +$$ + +其中 $\gamma(c)$ 和 $\beta(c)$ 是从条件 $c$ 预测出的缩放和偏移向量(维度均为 `d_model`),$\odot$ 为逐元素乘法。 + +FiLM 直接修改特征分布,条件信号强制影响所有 token 的表示,而不依赖 query 主动发起的检索。 + +### AdaLN(Adaptive Layer Normalization) + +AdaLN 将 FiLM 与 LayerNorm 结合,用条件向量预测 LayerNorm 的缩放和偏移参数,替代原有的固定参数: + +$$ +\text{AdaLN}(x, c) = \gamma(c) \odot \frac{x - \mu}{\sigma} + \beta(c) +$$ + +与标准 LayerNorm 的区别仅在于 $\gamma$ 和 $\beta$ 不是可学习的静态参数,而是由条件 $c$ 动态生成。AdaLN 在 DiT(Diffusion Transformer)和 MaskGIT 的改进版本中已有广泛验证。 + +**选用 AdaLN** 作为主要方案,理由: + +- 在 Transformer 架构中,LayerNorm 是特征归一化的核心节点,在此处注入条件效果最稳定 +- AdaLN 的参数量增加极少(仅新增 `2 * d_model` 的线性层输出) +- 与 FiLM 效果等价,但更符合 Transformer 的设计惯例 + +--- + +## 架构设计 + +### 条件向量的构建 + +将结构标签、密度标签和 VQ 码字 `z` 全部融合为**单一全局条件向量** `c`(维度 `d_model`),通过 AdaLN 在每一层强制施加到所有 map token 上。 + +**结构标签**(4 个离散标量)各自独立嵌入后**拼接**,再经 Linear 投影: + +``` +struct: [B, 4] → 各自 Embedding(d_cond) → cat → [B, 4*d_cond] → Linear → [B, d_model] +``` + +**密度标签**(3 个离散标量)各自独立嵌入后**拼接**,再经 Linear 投影(不使用相加,避免各档位嵌入相互抵消): + +``` +density: [B, 3] → 各自 Embedding(d_cond) → cat → [B, 3*d_cond] → Linear → [B, d_model] +``` + +**VQ 码字 z**(序列)先做均值池化压缩为单个向量,再经 Linear 投影: + +``` +z: [B, L*3, d_z] → mean(dim=1) → [B, d_z] → Linear → [B, d_model] +``` + +三路向量相加得到最终条件向量: + +``` +c = struct_vec + density_vec + z_vec # [B, d_model] +``` + +> 说明:`z` 改为全局注入的动机在于,训练前期模型观察到输出地图高度相似、多样性极低,表明 cross-attention 方式下模型未能有效利用随机采样的 `z`。均值池化保留了 `z` 序列的整体语义,同时将其压缩为标量条件,适合 AdaLN 注入。 + +### 自定义 Transformer 层 + +由于 PyTorch 的 `nn.TransformerEncoderLayer` / `nn.TransformerDecoderLayer` 不支持外部注入 AdaLN 参数,需要自行实现: + +#### AdaLN 模块 + +```python +class AdaLN(nn.Module): + # 自适应 LayerNorm:用条件向量 c 预测 LayerNorm 的 gamma 和 beta + def __init__(self, d_model: int, d_cond: int): + ... + self.norm = nn.LayerNorm(d_model, elementwise_affine=False) + self.proj = nn.Linear(d_cond, d_model * 2) # 输出 [gamma, beta] + + def forward(self, x, c): + # x: [B, S, d_model] + # c: [B, d_model] 全局条件向量 + gamma, beta = self.proj(c).chunk(2, dim=-1) # 各 [B, d_model] + return (1 + gamma.unsqueeze(1)) * self.norm(x) + beta.unsqueeze(1) +``` + +#### 自定义 Transformer 层 + +替换标准的 `TransformerEncoderLayer`,在每个 sub-layer 的 LayerNorm 处注入条件: + +```python +class CondTransformerLayer(nn.Module): + # 带 AdaLN 条件注入的 Transformer Encoder 层 + # 结构:AdaLN-Self-Attn → AdaLN-FFN + def __init__(self, d_model, nhead, dim_ff, d_cond): + ... + self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) + self.adaln1 = AdaLN(d_model, d_cond) # 自注意力前的归一化 + self.adaln2 = AdaLN(d_model, d_cond) # FFN 前的归一化 + self.ffn = nn.Sequential(Linear, GELU, Linear) + + def forward(self, x, c, key_padding_mask=None): + # Pre-norm 结构 + residual = x + x = self.adaln1(x, c) + x, _ = self.self_attn(x, x, x) + x = residual + x + + residual = x + x = self.adaln2(x, c) + x = self.ffn(x) + x = residual + x + return x +``` + +#### Cross-attention 层(移除) + +`z` 已改为通过均值池化后加入全局条件向量 `c`,由 AdaLN 注入每一层,不再需要单独的 cross-attention 层。整个 Transformer 退化为纯 encoder(自注意力)结构,仅由 `CondTransformerLayer` 堆叠而成,无 decoder。 + +### 整体前向流程 + +``` +map → tile_embed + pos_embed → x [B, H*W, d_model] + +struct: [B, 4] → 各自 Embed → cat → Linear → [B, d_model] +density: [B, 3] → 各自 Embed → cat → Linear → [B, d_model] +z: [B, L*3, d_z] → mean → Linear → [B, d_model] +c = struct_vec + density_vec + z_vec # [B, d_model] + +for each layer: + x = CondTransformerLayer(x, c) # AdaLN 自注意力,纯 encoder 结构 + +logits = output_fc(x) [B, H*W, num_classes] +``` + +--- + +## 参数量对比 + +以 `d_model=256, nhead=4, dim_ff=1024, num_layers=6` 为基准估算: + +| 模块 | 当前方案 | 新方案(AdaLN) | +| ------------------- | ------------------------ | -------------------------------------------- | +| 条件嵌入层 | 小(各 Embedding + MLP) | 小(相似,略有增加) | +| 每层 AdaLN 额外参数 | 0 | `2 * d_model * d_model = 131K` × 6 层 ≈ 786K | +| cross-attention 层 | 6 层完整 decoder | 0(移除,z 改为 AdaLN 全局注入) | +| 总参数量变化 | 基准 | +约 5~10%(可接受) | + +--- + +## 实现文件规划 + +| 文件 | 改动内容 | +| -------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | +| `ginka/maskGIT/maskGIT.py` | 重写 `Transformer` 为自定义纯 encoder 架构,新增 `AdaLN`、`CondTransformerLayer`;移除 `ZCrossAttentionLayer` | +| `ginka/maskGIT/model.py` | 更新 `GinkaMaskGIT`:struct/density/z 三路融合为条件向量 `c`,密度标签改为拼接,z 改为均值池化后注入;移除旧 cross-attention 路径 | +| `ginka/train_seperated.py` | 无需修改(接口不变,`forward` 签名保持) | + +--- + +## 预期效果 + +- 密度标签、结构标签、VQ 码字 `z` 三路均通过 AdaLN 在每一层强制影响特征分布,模型无法绕过任何一路条件 +- 密度标签改为拼接(而非相加),避免不同档位嵌入线性叠加时相互抵消,使各密度维度保持独立的表示空间 +- `z` 通过均值池化压缩为全局向量后注入,保留 codebook 多样性的同时消除对 cross-attention 的依赖,预期解决训练前期输出地图高度相似的问题 +- 架构简化为纯 encoder,去掉 encoder-decoder 分离结构,降低实现复杂度和计算量 diff --git a/ginka/maskGIT/maskGIT.py b/ginka/maskGIT/maskGIT.py index 51cf5b8..58fe1f8 100644 --- a/ginka/maskGIT/maskGIT.py +++ b/ginka/maskGIT/maskGIT.py @@ -1,29 +1,57 @@ +import torch import torch.nn as nn -class Transformer(nn.Module): - def __init__( - self, d_model=256, dim_ff=512, nhead=8, num_layers=4, - ): +class AdaLN(nn.Module): + # 自适应 LayerNorm:条件向量 c 动态预测 LayerNorm 的 gamma 和 beta + def __init__(self, d_model: int, d_cond: int): super().__init__() - self.encoder = nn.TransformerEncoder( - nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'), - num_layers=num_layers + self.norm = nn.LayerNorm(d_model, elementwise_affine=False) + self.proj = nn.Linear(d_cond, d_model * 2) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + # x: [B, S, d_model] c: [B, d_cond] + gamma, beta = self.proj(c).chunk(2, dim=-1) # 各 [B, d_model] + return (1 + gamma.unsqueeze(1)) * self.norm(x) + beta.unsqueeze(1) + +class CondTransformerLayer(nn.Module): + # 带 AdaLN 条件注入的 Transformer Encoder 层 + # 结构:AdaLN → Self-Attn → 残差;AdaLN → FFN → 残差(Pre-norm) + def __init__(self, d_model: int, nhead: int, dim_ff: int, d_cond: int): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) + self.adaln1 = AdaLN(d_model, d_cond) + self.adaln2 = AdaLN(d_model, d_cond) + self.ffn = nn.Sequential( + nn.Linear(d_model, dim_ff), + nn.GELU(), + nn.Linear(dim_ff, d_model) ) - self.decoder = nn.TransformerDecoder( - nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'), - num_layers=num_layers - ) - - def forward(self, x, memory=None): - # x: [B, S, d_model] 地图 token 序列 - # memory: [B, L, d_model] 可选的 z 投影,用于 cross-attention - # 若 memory 为 None,则退化为原始自编解码行为(向后兼容) - enc_out = self.encoder(x) - if memory is not None: - # encoder 输出作为 query,z 作为 key/value - out = self.decoder(enc_out, memory) - else: - out = self.decoder(x, enc_out) - return out - \ No newline at end of file + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + # x: [B, S, d_model] c: [B, d_cond] + residual = x + normed = self.adaln1(x, c) + x, _ = self.self_attn(normed, normed, normed) + x = residual + x + residual = x + x = self.ffn(self.adaln2(x, c)) + x = residual + x + return x + +class Transformer(nn.Module): + # 纯 encoder Transformer,每层使用 AdaLN 注入全局条件向量 c + def __init__( + self, d_model: int = 256, dim_ff: int = 512, + nhead: int = 8, num_layers: int = 4 + ): + super().__init__() + self.layers = nn.ModuleList([ + CondTransformerLayer(d_model=d_model, nhead=nhead, dim_ff=dim_ff, d_cond=d_model) + for _ in range(num_layers) + ]) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + # x: [B, S, d_model] c: [B, d_model] 全局条件向量 + for layer in self.layers: + x = layer(x, c) + return x diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index bb2390f..2f5764c 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -19,7 +19,7 @@ class GinkaMaskGIT(nn.Module): def __init__( self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512, nhead: int = 8, num_layers: int = 4, map_h: int = 13, map_w: int = 13, - d_z: int = 64 + d_z: int = 64, z_seq_len: int = 6 ): super().__init__() self.map_h = map_h @@ -30,53 +30,31 @@ class GinkaMaskGIT(nn.Module): self.row_embedding = nn.Parameter(torch.randn(1, map_h, d_model) * 0.02) self.col_embedding = nn.Parameter(torch.randn(1, map_w, d_model) * 0.02) - # z 投影:将 VQ 码字从 d_z 维映射到 d_model 维,供 cross-attention 使用 - self.z_proj = nn.Sequential( - nn.Linear(d_z, d_model * 2), - nn.LayerNorm(d_model * 2), - nn.GELU(), - - nn.Linear(d_model * 2, d_model), - nn.LayerNorm(d_model) - ) - - # 结构标签嵌入(编码到 d_z 维度) - # 注意:结构标签与 VQ 码字语义不同,使用独立投影层避免混用 + # 结构标签嵌入:各自独立嵌入到 d_z 维度,作为独立 token self.sym_embed = nn.Embedding(SYM_VOCAB, d_z) self.room_embed = nn.Embedding(ROOM_VOCAB, d_z) self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z) self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) - self.struct_proj = nn.Sequential( - nn.Linear(d_z, d_model * 2), - nn.LayerNorm(d_model * 2), - nn.GELU(), - - nn.Linear(d_model * 2, d_model), - nn.LayerNorm(d_model) - ) + # 密度标签嵌入:各自独立嵌入到 d_z 维度,作为独立 token + self.door_density_embed = nn.Embedding(DOOR_DENSITY_VOCAB, d_z) + self.monster_density_embed = nn.Embedding(MONSTER_DENSITY_VOCAB, d_z) + self.resource_density_embed = nn.Embedding(RESOURCE_DENSITY_VOCAB, d_z) - # Transformer:encoder 做 map token 自注意力,decoder 做与 z 的 cross-attention + # z 投影:逐 token 线性变换,保持序列结构 + self.z_proj = nn.Linear(d_z, d_z) + + # 条件融合投影:将 (z_seq_len + 4 + 3) 个 d_z 维 token 拼接后降维到 d_model + # 拼接顺序:z_seq_len 个 z token + 4 个结构 token + 3 个密度 token + self.cond_proj = nn.Linear((z_seq_len + 7) * d_z, d_model) + + # 纯 encoder Transformer,条件向量 c 通过 AdaLN 注入每一层 self.transformer = Transformer( d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers ) self.output_fc = nn.Linear(d_model, num_classes) - # 密度标签嵌入 + 独立 MLP(与结构路径完全分离) - # 三个密度 embedding 相加后经两层 MLP 映射为单个条件 token - self.door_density_embed = nn.Embedding(DOOR_DENSITY_VOCAB, d_z) - self.monster_density_embed = nn.Embedding(MONSTER_DENSITY_VOCAB, d_z) - self.resource_density_embed = nn.Embedding(RESOURCE_DENSITY_VOCAB, d_z) - self.density_mlp = nn.Sequential( - nn.Linear(d_z, d_model * 2), - nn.LayerNorm(d_model * 2), - nn.GELU(), - - nn.Linear(d_model * 2, d_model), - nn.LayerNorm(d_model) - ) - def forward( self, map: torch.Tensor, @@ -85,36 +63,31 @@ class GinkaMaskGIT(nn.Module): density: torch.Tensor ) -> torch.Tensor: # map: [B, H * W] - # z: [B, L * 3, d_z] + # z: [B, z_seq_len, d_z] # struct: [B, 4] # density: [B, 3] — [door_level, monster_level, resource_level] - sym_idx = struct[:, 0] - room_idx = struct[:, 1] - branch_idx = struct[:, 2] - outer_idx = struct[:, 3] + # 结构标签:各自嵌入为独立 token,stack 成序列 [B, 4, d_z] + e_struct = torch.stack([ + self.sym_embed(struct[:, 0]), + self.room_embed(struct[:, 1]), + self.branch_embed(struct[:, 2]), + self.outer_embed(struct[:, 3]) + ], dim=1) - # 嵌入结构标签到 d_z 维度,拼接到 z 序列末尾 - e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z] - e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z] - e_branch = self.branch_embed(branch_idx).unsqueeze(1) # [B, 1, d_z] - e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z] - - struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z] - - # VQ 码字与结构标签语义不同,使用各自独立的投影层后再拼接 - z_mem_vq = self.z_proj(z) # [B, L, d_model] - z_mem_struct = self.struct_proj(struct_seq) # [B, 4, d_model] - - # 密度条件:三个 embedding 相加后经独立 MLP 得到单个条件 token - e_density = ( - self.door_density_embed(density[:, 0]) + - self.monster_density_embed(density[:, 1]) + + # 密度标签:各自嵌入为独立 token,stack 成序列 [B, 3, d_z] + e_density = torch.stack([ + self.door_density_embed(density[:, 0]), + self.monster_density_embed(density[:, 1]), self.resource_density_embed(density[:, 2]) - ) # [B, d_z] - density_token = self.density_mlp(e_density).unsqueeze(1) # [B, 1, d_model] + ], dim=1) - z_mem = torch.cat([z_mem_vq, z_mem_struct, density_token], dim=1) # [B, L*3+5, d_model] + # z:逐 token 投影,保留序列结构 [B, z_seq_len, d_z] + z_proj = self.z_proj(z) + + # 拼接所有条件 token → [B, z_seq_len+7, d_z],展平后投影到 d_model + cond_seq = torch.cat([z_proj, e_struct, e_density], dim=1) + c = self.cond_proj(cond_seq.reshape(cond_seq.size(0), -1)) # [B, d_model] # tile embedding + 位置编码 row_idx = torch.arange(self.map_h, device=map.device).repeat_interleave(self.map_w) @@ -122,8 +95,8 @@ class GinkaMaskGIT(nn.Module): pos = self.row_embedding[0, row_idx] + self.col_embedding[0, col_idx] # [H*W, d_model] x = self.tile_embedding(map) + pos # [B, H * W, d_model] - # Transformer:encoder 做 map 自注意力,decoder cross-attend z+struct - x = self.transformer(x, memory=z_mem) # [B, H * W, d_model] + # Transformer:纯 encoder,每层通过 AdaLN 接收全局条件向量 c + x = self.transformer(x, c) # [B, H * W, d_model] logits = self.output_fc(x) # [B, H * W, num_classes] return logits @@ -132,29 +105,36 @@ if __name__ == "__main__": device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") map_input = torch.randint(0, 7, (4, 13 * 13)).to(device) # [4, 169] - z_input = torch.randn(4, 2, 64).to(device) # [4, 2, 64] + z_input = torch.randn(4, 6, 64).to(device) # [4, L*3, 64] struct_input = torch.tensor([ [3, 1, 0, 1], [0, 2, 1, 0], [5, 1, 2, 1], [1, 0, 1, 0], ], dtype=torch.long).to(device) # [4, 4] + density_input = torch.tensor([ + [0, 1, 2], + [2, 0, 1], + [1, 2, 0], + [0, 0, 1], + ], dtype=torch.long).to(device) # [4, 3] model = GinkaMaskGIT( num_classes=7, - d_model=192, + d_model=256, d_z=64, - dim_ff=2048, - nhead=8, + dim_ff=1024, + nhead=4, num_layers=6, map_h=13, - map_w=13 + map_w=13, + z_seq_len=6 ).to(device) print_memory(device, "初始化后") start = time.perf_counter() - logits = model(map_input, z_input, struct_input) + logits = model(map_input, z_input, struct_input, density_input) end = time.perf_counter() print_memory(device, "前向传播后") @@ -162,8 +142,9 @@ if __name__ == "__main__": print(f"推理耗时: {end - start:.4f}s") print(f"输出形状: logits={logits.shape}") print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}") - print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}") print(f"Struct Projection parameters: {sum(p.numel() for p in model.struct_proj.parameters())}") + print(f"Density Projection parameters: {sum(p.numel() for p in model.density_proj.parameters())}") + print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}") print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}") print(f"Output FC parameters: {sum(p.numel() for p in model.output_fc.parameters())}") print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/train_seperated.py b/ginka/train_seperated.py index 6588756..4231db2 100644 --- a/ginka/train_seperated.py +++ b/ginka/train_seperated.py @@ -136,15 +136,18 @@ def build_model(device: torch.device): # 三个独立 MaskGIT 解码器,均接收完整的三阶段 z_q 作为条件 mg1 = GinkaMaskGIT( num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF, - nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W + nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W, + z_seq_len=VQ_L * 3 ).to(device) mg2 = GinkaMaskGIT( num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF, - nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W + nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W, + z_seq_len=VQ_L * 3 ).to(device) mg3 = GinkaMaskGIT( num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF, - nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W + nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W, + z_seq_len=VQ_L * 3 ).to(device) # 六个模型参数合并到同一优化器,端到端联合训练 @@ -596,7 +599,7 @@ def visualize_density_var(batch, z_q, models, device, tile_dict): inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE) struct_t = batch["struct_inject"][0:1].to(device) struct_cpu = batch["struct_inject"][0] - ref_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W) + ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W) gen_imgs = [] for _ in range(5): rnd_density = random_density(device)