# ginka-generator/ginka/maskGIT/model.py

import time
import torch
import torch.nn as nn
from ..utils import print_memory
from .maskGIT import Transformer

# Structural-label vocabulary sizes
SYM_VOCAB = 8      # symmetry: 3-bit H/V/C combination, 0-7
ROOM_VOCAB = 3     # roomCountLevel, 0-2
BRANCH_VOCAB = 3   # branchLevel, 0-2
OUTER_VOCAB = 2    # outerWall, 0-1
# Density-label vocabulary sizes (three levels: Low/Medium/High)
DOOR_DENSITY_VOCAB = 3
MONSTER_DENSITY_VOCAB = 3
RESOURCE_DENSITY_VOCAB = 3
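# For illustration only (an assumed encoding, not confirmed by this file): the
# symmetry label may pack the three flags as sym = (H << 2) | (V << 1) | C,
# e.g. H=1, V=0, C=1 -> 0b101 = 5, so 3 bits cover the 8 ids 0-7.
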
class GinkaMaskGIT(nn.Module):
    def __init__(
        self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512,
        nhead: int = 8, num_layers: int = 4, map_h: int = 13, map_w: int = 13,
        d_z: int = 64, z_seq_len: int = 6
    ):
        super().__init__()
        self.map_h = map_h
        self.map_w = map_w
        # Tile embedding + factorized 2D positional encoding
        self.tile_embedding = nn.Embedding(num_classes, d_model)
        self.row_embedding = nn.Parameter(torch.randn(1, map_h, d_model) * 0.02)
        self.col_embedding = nn.Parameter(torch.randn(1, map_w, d_model) * 0.02)
        # Structural-label embeddings: each label is embedded independently into d_z dims as its own token
        self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
        self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
        self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
        self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
        # Density-label embeddings: each label is embedded independently into d_z dims as its own token
        self.door_density_embed = nn.Embedding(DOOR_DENSITY_VOCAB, d_z)
        self.monster_density_embed = nn.Embedding(MONSTER_DENSITY_VOCAB, d_z)
        self.resource_density_embed = nn.Embedding(RESOURCE_DENSITY_VOCAB, d_z)
        # z projection: per-token linear transform, preserving the sequence structure
        self.z_proj = nn.Linear(d_z, d_z)
        # Condition-fusion projection: concatenate the (z_seq_len + 4 + 3) d_z-dim tokens, then project down to d_model
        # Concatenation order: z_seq_len z tokens + 4 structural tokens + 3 density tokens
        self.cond_proj = nn.Linear((z_seq_len + 7) * d_z, d_model)
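        # e.g., with the defaults z_seq_len=6 and d_z=64: in_features = (6 + 7) * 64 = 832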
        # Encoder-only Transformer; the condition vector c is injected into every layer via AdaLN
        self.transformer = Transformer(
            d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
        )
        self.output_fc = nn.Linear(d_model, num_classes)

    def forward(
        self,
        map: torch.Tensor,
        z: torch.Tensor,
        struct: torch.Tensor,
        density: torch.Tensor
    ) -> torch.Tensor:
        # map: [B, H * W]
        # z: [B, z_seq_len, d_z]
        # struct: [B, 4]
        # density: [B, 3] = [door_level, monster_level, resource_level]
        # Structural labels: embed each as an independent token, stacked into a sequence [B, 4, d_z]
        e_struct = torch.stack([
            self.sym_embed(struct[:, 0]),
            self.room_embed(struct[:, 1]),
            self.branch_embed(struct[:, 2]),
            self.outer_embed(struct[:, 3])
        ], dim=1)
        # Density labels: embed each as an independent token, stacked into a sequence [B, 3, d_z]
        e_density = torch.stack([
            self.door_density_embed(density[:, 0]),
            self.monster_density_embed(density[:, 1]),
            self.resource_density_embed(density[:, 2])
        ], dim=1)
        # z: per-token projection, preserving the sequence structure [B, z_seq_len, d_z]
        z_proj = self.z_proj(z)
        # Concatenate all condition tokens -> [B, z_seq_len + 7, d_z], flatten, then project to d_model
        cond_seq = torch.cat([z_proj, e_struct, e_density], dim=1)
        c = self.cond_proj(cond_seq.reshape(cond_seq.size(0), -1))  # [B, d_model]
        # Tile embedding + positional encoding
        row_idx = torch.arange(self.map_h, device=map.device).repeat_interleave(self.map_w)
        col_idx = torch.arange(self.map_w, device=map.device).repeat(self.map_h)
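        # e.g., for a 3-wide map: row_idx = [0,0,0,1,1,1,...], col_idx = [0,1,2,0,1,2,...],
        # so every cell receives the sum of its row embedding and column embedding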
        pos = self.row_embedding[0, row_idx] + self.col_embedding[0, col_idx]  # [H*W, d_model]
        x = self.tile_embedding(map) + pos  # [B, H * W, d_model]
        # Encoder-only Transformer: each layer receives the global condition vector c via AdaLN
        x = self.transformer(x, c)  # [B, H * W, d_model]
        logits = self.output_fc(x)  # [B, H * W, num_classes]
        return logits
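
# --- Illustrative sampling sketch (an addition, not part of the original file) ---
# A minimal sketch of MaskGIT-style iterative decoding with this model, assuming a
# dedicated [MASK] token id and a cosine unmasking schedule; the function name,
# mask_id convention, and schedule are illustrative assumptions, not the project's
# confirmed inference code.
import math

@torch.no_grad()
def maskgit_sample(model, z, struct, density, num_steps=8, mask_id=None):
    B = z.size(0)
    seq_len = model.map_h * model.map_w
    num_classes = model.output_fc.out_features
    # Assumption: the last class id acts as the [MASK] placeholder
    mask_id = num_classes - 1 if mask_id is None else mask_id
    tokens = torch.full((B, seq_len), mask_id, dtype=torch.long, device=z.device)
    for step in range(num_steps):
        logits = model(tokens, z, struct, density)      # [B, seq_len, num_classes]
        conf, pred = torch.softmax(logits, dim=-1).max(dim=-1)
        still_masked = tokens.eq(mask_id)
        # Cosine schedule: fraction of cells that stay masked after this step
        keep_masked = int(seq_len * math.cos(math.pi / 2 * (step + 1) / num_steps))
        # Already-committed cells get infinite confidence so they are never re-masked
        conf = conf.masked_fill(~still_masked, float("inf"))
        if keep_masked > 0:
            # Per-batch threshold: the keep_masked-th lowest confidence
            thresh = conf.sort(dim=-1).values[:, keep_masked - 1:keep_masked]
            commit = still_masked & (conf > thresh)
        else:
            commit = still_masked  # final step: commit everything left
        tokens = torch.where(commit, pred, tokens)
    return tokens.view(B, model.map_h, model.map_w)
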
if __name__ == "__main__":
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    map_input = torch.randint(0, 7, (4, 13 * 13)).to(device)  # [4, 169]
    z_input = torch.randn(4, 6, 64).to(device)  # [4, z_seq_len, d_z] = [4, 6, 64]
    struct_input = torch.tensor([
        [3, 1, 0, 1],
        [0, 2, 1, 0],
        [5, 1, 2, 1],
        [1, 0, 1, 0],
    ], dtype=torch.long).to(device)  # [4, 4]
    density_input = torch.tensor([
        [0, 1, 2],
        [2, 0, 1],
        [1, 2, 0],
        [0, 0, 1],
    ], dtype=torch.long).to(device)  # [4, 3]
    model = GinkaMaskGIT(
        num_classes=7,
        d_model=256,
        d_z=64,
        dim_ff=1024,
        nhead=4,
        num_layers=6,
        map_h=13,
        map_w=13,
        z_seq_len=6
    ).to(device)
    print_memory(device, "after init")
    start = time.perf_counter()
    logits = model(map_input, z_input, struct_input, density_input)
    end = time.perf_counter()
    print_memory(device, "after forward pass")
    print(f"Inference time: {end - start:.4f}s")
    print(f"Output shape: logits={logits.shape}")
print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
print(f"Struct Projection parameters: {sum(p.numel() for p in model.struct_proj.parameters())}")
print(f"Density Projection parameters: {sum(p.numel() for p in model.density_proj.parameters())}")
print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}")
print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}")
print(f"Output FC parameters: {sum(p.numel() for p in model.output_fc.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")