import time
import torch
import torch.nn as nn
from ..utils import print_memory
from .maskGIT import Transformer

# Vocabulary sizes for the structure labels
SYM_VOCAB = 8     # symmetryH/V/C packed as a 3-bit combination, 0-7
ROOM_VOCAB = 3    # roomCountLevel, 0-2
BRANCH_VOCAB = 3  # branchLevel, 0-2
OUTER_VOCAB = 2   # outerWall, 0-1
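
# Illustrative sketch (not part of the original file): SYM_VOCAB = 8 implies the
# three symmetry flags are packed into a single 3-bit index. The bit order used
# here (H as the high bit, C as the low bit) is an assumption for illustration.
def pack_symmetry(h: bool, v: bool, c: bool) -> int:
    """Pack symmetryH/V/C flags into one index in [0, 8)."""
    return (int(h) << 2) | (int(v) << 1) | int(c)
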
class GinkaMaskGIT(nn.Module):
    def __init__(
        self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512,
        nhead: int = 8, num_layers: int = 4, map_size: int = 13 * 13, d_z: int = 64
    ):
        """
        Args:
            num_classes: number of tile classes (including the MASK token, id 15)
            d_model: internal Transformer dimension
            dim_ff: hidden dimension of the feed-forward network
            nhead: number of attention heads
            num_layers: number of Transformer layers
            map_size: total number of map tokens (H * W)
            d_z: dimension of the VQ codes (and of the structure-label embeddings)
        """
        super().__init__()

        # Tile embedding + positional encoding
        self.tile_embedding = nn.Embedding(num_classes, d_model)
        self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02)

        # z projection: maps VQ codes from d_z to d_model for cross-attention
        self.z_proj = nn.Sequential(
            nn.Linear(d_z, d_model * 2),
            nn.LayerNorm(d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model)
        )

        # Structure-label embeddings (encoded at dimension d_z)
        # Note: structure labels and VQ codes carry different semantics, so an
        # independent projection layer is used to avoid mixing them
        self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
        self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
        self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
        self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)

        self.struct_proj = nn.Sequential(
            nn.Linear(d_z, d_model * 2),
            nn.LayerNorm(d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model)
        )

        # Transformer: the encoder runs self-attention over the map tokens,
        # the decoder cross-attends to z
        self.transformer = Transformer(
            d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
        )

        self.output_fc = nn.Linear(d_model, num_classes)

    def forward(
        self,
        map: torch.Tensor,
        z: torch.Tensor,
        struct: torch.Tensor
    ) -> torch.Tensor:
        # map: [B, H * W]
        # z: [B, L * 3, d_z]
        # struct: [B, 4]

        sym_idx = struct[:, 0]
        room_idx = struct[:, 1]
        branch_idx = struct[:, 2]
        outer_idx = struct[:, 3]

        # Embed the structure labels at dimension d_z and append them to the
        # end of the z sequence
        e_sym = self.sym_embed(sym_idx).unsqueeze(1)            # [B, 1, d_z]
        e_room = self.room_embed(room_idx).unsqueeze(1)         # [B, 1, d_z]
        e_branch = self.branch_embed(branch_idx).unsqueeze(1)   # [B, 1, d_z]
        e_outer = self.outer_embed(outer_idx).unsqueeze(1)      # [B, 1, d_z]

        struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1)  # [B, 4, d_z]

        # VQ codes and structure labels carry different semantics, so each goes
        # through its own projection layer before concatenation
        z_mem_vq = self.z_proj(z)                    # [B, L * 3, d_model]
        z_mem_struct = self.struct_proj(struct_seq)  # [B, 4, d_model]
        z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1)  # [B, L * 3 + 4, d_model]

        # Tile embedding + positional encoding
        x = self.tile_embedding(map)  # [B, H * W, d_model]
        x = x + self.pos_embedding    # [B, H * W, d_model]

        # Transformer: encoder self-attention over the map, decoder
        # cross-attention over z + struct
        x = self.transformer(x, memory=z_mem)  # [B, H * W, d_model]

        logits = self.output_fc(x)  # [B, H * W, num_classes]
        return logits

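
# --- Illustrative sketch (not part of the original file) ---------------------
# The class name suggests MaskGIT-style inference: start from an all-MASK map
# and iteratively commit the most confident predictions. The MASK id
# (num_classes - 1) and the cosine unmasking schedule below are assumptions,
# shown only to clarify how the logits from forward() could be consumed.
@torch.no_grad()
def maskgit_decode_sketch(
    model: GinkaMaskGIT, z: torch.Tensor, struct: torch.Tensor,
    map_size: int = 13 * 13, num_classes: int = 16, steps: int = 8
) -> torch.Tensor:
    import math
    mask_id = num_classes - 1  # assumed: MASK is the last class id
    B = z.size(0)
    tokens = torch.full((B, map_size), mask_id, dtype=torch.long, device=z.device)
    for step in range(steps):
        logits = model(tokens, z, struct)     # [B, map_size, num_classes]
        logits[..., mask_id] = float("-inf")  # never sample the MASK class
        conf, pred = logits.softmax(dim=-1).max(dim=-1)  # [B, map_size] each
        # positions decoded in earlier steps stay fixed
        conf = torch.where(tokens == mask_id, conf, torch.full_like(conf, float("inf")))
        tokens = torch.where(tokens == mask_id, pred, tokens)
        # cosine schedule: fraction of tokens that remain masked after this step
        n_mask = int(map_size * math.cos(math.pi / 2 * (step + 1) / steps))
        if n_mask > 0:
            # re-mask the n_mask least confident positions
            lowest = conf.argsort(dim=-1)[:, :n_mask]
            tokens.scatter_(1, lowest, mask_id)
    return tokens
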
if __name__ == "__main__":
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    map_input = torch.randint(0, 7, (4, 13 * 13)).to(device)  # [4, 169]
    z_input = torch.randn(4, 2, 64).to(device)                # [4, 2, 64]
    struct_input = torch.tensor([
        [3, 1, 0, 1],
        [0, 2, 1, 0],
        [5, 1, 2, 1],
        [1, 0, 1, 0],
    ], dtype=torch.long).to(device)  # [4, 4]

    model = GinkaMaskGIT(
        num_classes=7,
        d_model=192,
        d_z=64,
        dim_ff=2048,
        nhead=8,
        num_layers=6,
        map_size=13 * 13,
    ).to(device)

    print_memory(device, "after initialization")

    start = time.perf_counter()
    logits = model(map_input, z_input, struct_input)
    end = time.perf_counter()

    print_memory(device, "after forward pass")

    print(f"Inference time: {end - start:.4f}s")
    print(f"Output shape: logits={logits.shape}")
    print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
    print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}")
    print(f"Struct Projection parameters: {sum(p.numel() for p in model.struct_proj.parameters())}")
    print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}")
    print(f"Output FC parameters: {sum(p.numel() for p in model.output_fc.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
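
    # Hedged sanity check (not in the original): forward returns one logit
    # vector per map token, i.e. [B, H * W, num_classes] = [4, 169, 7].
    assert logits.shape == (4, 13 * 13, 7)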