mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 06:51:11 +08:00
151 lines
5.8 KiB
Python
151 lines
5.8 KiB
Python
import time
|
||
import torch
|
||
import torch.nn as nn
|
||
from ..utils import print_memory
|
||
from .maskGIT import Transformer
|
||
|
||
# Vocabulary sizes of the structure labels; each label is embedded as its own token.
SYM_VOCAB = 1 << 3  # symmetryH / symmetryV / symmetryC combined as 3 bits -> 0..7
ROOM_VOCAB = 3  # roomCountLevel in {0, 1, 2}
BRANCH_VOCAB = 3  # branchLevel in {0, 1, 2}
OUTER_VOCAB = 2  # outerWall in {0, 1}

# Vocabulary sizes of the density labels: three levels each (Low / Medium / High).
DOOR_DENSITY_VOCAB = MONSTER_DENSITY_VOCAB = RESOURCE_DENSITY_VOCAB = 3

class GinkaMaskGIT(nn.Module):
    """Predict per-tile class logits for a flattened map under global conditioning.

    The map is embedded as a sequence of tile tokens plus a factorised 2-D
    (row + column) positional encoding. A single conditioning vector ``c`` is
    built from three sources:

    * ``z``: a sequence of ``z_seq_len`` latent tokens of width ``d_z``,
      passed through a per-token linear projection;
    * four structure labels (symmetry, room count, branch level, outer wall),
      each embedded separately into ``d_z``;
    * three density labels (door, monster, resource), each embedded
      separately into ``d_z``.

    All condition tokens are concatenated, flattened and projected to
    ``d_model``. ``c`` is then handed to the encoder-only ``Transformer``
    together with the tile sequence. Per the original design note, the
    Transformer injects ``c`` into every layer via AdaLN.
    """

    # Number of label tokens appended after the z tokens in the condition sequence.
    NUM_STRUCT_TOKENS = 4  # sym, room, branch, outer
    NUM_DENSITY_TOKENS = 3  # door, monster, resource

    def __init__(
        self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512,
        nhead: int = 8, num_layers: int = 4, map_h: int = 13, map_w: int = 13,
        d_z: int = 64, z_seq_len: int = 6
    ):
        """Build the model.

        Args:
            num_classes: number of tile classes (input vocabulary and output logits).
            d_model: hidden width of the tile tokens and the Transformer.
            dim_ff: feed-forward width inside the Transformer.
            nhead: number of attention heads.
            num_layers: number of Transformer layers.
            map_h: map height in tiles.
            map_w: map width in tiles.
            d_z: width of each latent / label condition token.
            z_seq_len: number of latent tokens expected in ``z``.
        """
        super().__init__()
        self.map_h = map_h
        self.map_w = map_w
        # Kept so forward() can validate the shape of z.
        self.d_z = d_z
        self.z_seq_len = z_seq_len

        # Tile embedding + factorised 2-D positional encoding (row and column tables).
        self.tile_embedding = nn.Embedding(num_classes, d_model)
        self.row_embedding = nn.Parameter(torch.randn(1, map_h, d_model) * 0.02)
        self.col_embedding = nn.Parameter(torch.randn(1, map_w, d_model) * 0.02)

        # Structure labels: each embedded independently into d_z, used as a separate token.
        self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
        self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
        self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
        self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)

        # Density labels: each embedded independently into d_z, used as a separate token.
        self.door_density_embed = nn.Embedding(DOOR_DENSITY_VOCAB, d_z)
        self.monster_density_embed = nn.Embedding(MONSTER_DENSITY_VOCAB, d_z)
        self.resource_density_embed = nn.Embedding(RESOURCE_DENSITY_VOCAB, d_z)

        # Per-token linear projection of z; keeps the sequence structure.
        self.z_proj = nn.Linear(d_z, d_z)

        # Condition fusion: the concatenated condition tokens are flattened and
        # projected down to d_model. Order: z tokens, struct tokens, density tokens.
        num_cond_tokens = z_seq_len + self.NUM_STRUCT_TOKENS + self.NUM_DENSITY_TOKENS
        self.cond_proj = nn.Linear(num_cond_tokens * d_z, d_model)

        # Encoder-only Transformer; receives the condition vector c in forward().
        self.transformer = Transformer(
            d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
        )

        self.output_fc = nn.Linear(d_model, num_classes)

    def _encode_condition(
        self, z: torch.Tensor, struct: torch.Tensor, density: torch.Tensor
    ) -> torch.Tensor:
        """Fuse z, structure labels and density labels into one vector per sample."""
        # Structure labels -> [B, 4, d_z]
        e_struct = torch.stack([
            self.sym_embed(struct[:, 0]),
            self.room_embed(struct[:, 1]),
            self.branch_embed(struct[:, 2]),
            self.outer_embed(struct[:, 3]),
        ], dim=1)

        # Density labels -> [B, 3, d_z]
        e_density = torch.stack([
            self.door_density_embed(density[:, 0]),
            self.monster_density_embed(density[:, 1]),
            self.resource_density_embed(density[:, 2]),
        ], dim=1)

        # z projected token-by-token -> [B, z_seq_len, d_z]
        z_tokens = self.z_proj(z)

        # [B, z_seq_len + 7, d_z] -> flatten -> [B, d_model]
        cond_seq = torch.cat([z_tokens, e_struct, e_density], dim=1)
        return self.cond_proj(cond_seq.reshape(cond_seq.size(0), -1))

    def _positional_encoding(self) -> torch.Tensor:
        """Return the [H*W, d_model] positional table in row-major tile order."""
        # [H, 1, d] + [1, W, d] -> [H, W, d]; flattening gives index r * W + c,
        # matching the row-major layout of the flattened map.
        grid = self.row_embedding[0, :, None, :] + self.col_embedding[0, None, :, :]
        return grid.reshape(self.map_h * self.map_w, -1)

    def forward(
        self,
        map: torch.Tensor,
        z: torch.Tensor,
        struct: torch.Tensor,
        density: torch.Tensor
    ) -> torch.Tensor:
        """Compute tile logits.

        Args:
            map: tile indices, [B, H * W] (row-major). Named ``map`` for API
                compatibility even though it shadows the builtin.
            z: latent tokens, [B, z_seq_len, d_z].
            struct: structure labels, [B, 4] = [sym, room, branch, outer].
            density: density labels, [B, 3] = [door_level, monster_level, resource_level].

        Returns:
            Logits from ``output_fc``; [B, H * W, num_classes] assuming the
            Transformer preserves the sequence shape.

        Raises:
            ValueError: if ``map`` or ``z`` do not match the configured shapes.
        """
        num_tiles = self.map_h * self.map_w
        if map.shape[-1] != num_tiles:
            raise ValueError(
                f"map must have last dimension H * W = {num_tiles}, got shape {tuple(map.shape)}"
            )
        if z.dim() != 3 or tuple(z.shape[1:]) != (self.z_seq_len, self.d_z):
            raise ValueError(
                f"z must have shape [B, {self.z_seq_len}, {self.d_z}], got {tuple(z.shape)}"
            )

        c = self._encode_condition(z, struct, density)  # [B, d_model]

        # Tile embedding + positional encoding -> [B, H * W, d_model]
        x = self.tile_embedding(map) + self._positional_encoding()

        # Encoder-only Transformer conditioned on the global vector c.
        x = self.transformer(x, c)  # [B, H * W, d_model]

        return self.output_fc(x)  # [B, H * W, num_classes]

if __name__ == "__main__":
    # Smoke test / micro-benchmark: one forward pass on random inputs, then
    # report memory, timing, output shape and parameter counts per component.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def _count_params(*modules: nn.Module) -> int:
        """Total number of parameters across the given modules."""
        return sum(p.numel() for module in modules for p in module.parameters())

    def _sync() -> None:
        """Wait for queued CUDA kernels so wall-clock timing covers the real work."""
        if device.type == "cuda":
            torch.cuda.synchronize()

    map_input = torch.randint(0, 7, (4, 13 * 13)).to(device)  # [4, 169]
    z_input = torch.randn(4, 6, 64).to(device)  # [4, z_seq_len, d_z]
    struct_input = torch.tensor([
        [3, 1, 0, 1],
        [0, 2, 1, 0],
        [5, 1, 2, 1],
        [1, 0, 1, 0],
    ], dtype=torch.long).to(device)  # [4, 4]
    density_input = torch.tensor([
        [0, 1, 2],
        [2, 0, 1],
        [1, 2, 0],
        [0, 0, 1],
    ], dtype=torch.long).to(device)  # [4, 3]

    model = GinkaMaskGIT(
        num_classes=7,
        d_model=256,
        d_z=64,
        dim_ff=1024,
        nhead=4,
        num_layers=6,
        map_h=13,
        map_w=13,
        z_seq_len=6
    ).to(device)

    print_memory(device, "初始化后")

    _sync()
    start = time.perf_counter()
    logits = model(map_input, z_input, struct_input, density_input)
    _sync()
    end = time.perf_counter()

    print_memory(device, "前向传播后")

    print(f"推理耗时: {end - start:.4f}s")
    print(f"输出形状: logits={logits.shape}")
    print(f"Tile Embedding parameters: {_count_params(model.tile_embedding)}")
    print(f"Positional Embedding parameters: {model.row_embedding.numel() + model.col_embedding.numel()}")
    print(f"Struct Embedding parameters: {_count_params(model.sym_embed, model.room_embed, model.branch_embed, model.outer_embed)}")
    print(f"Density Embedding parameters: {_count_params(model.door_density_embed, model.monster_density_embed, model.resource_density_embed)}")
    print(f"Z Projection parameters: {_count_params(model.z_proj)}")
    print(f"Condition Projection parameters: {_count_params(model.cond_proj)}")
    print(f"Transformer parameters: {_count_params(model.transformer)}")
    print(f"Output FC parameters: {_count_params(model.output_fc)}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")