mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: RNN 融合编码
This commit is contained in:
parent
1ccac9e60d
commit
c79662089b
@ -3,53 +3,233 @@ import torch
|
||||
import time

import torch.nn as nn
import torch.nn.functional as F
class RNNConditionEncoder(nn.Module):
    """Encode the numeric condition vector into a fixed-size embedding.

    Args:
        val_dim: dimensionality of the raw condition vector.
        output_dim: dimensionality of the produced condition embedding.
        width, height: accepted for interface compatibility; currently unused
            in this encoder — TODO confirm whether they should shape the output.
    """

    def __init__(self, val_dim=16, output_dim=256, width=13, height=13):
        super().__init__()

        # Project the raw condition values into the embedding space.
        self.val_fc = nn.Sequential(
            nn.Linear(val_dim, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
        )
        # Final linear mixing layer applied on top of the projection.
        self.fusion = nn.Sequential(
            nn.Linear(output_dim, output_dim)
        )

    def forward(self, val_cond: torch.Tensor) -> torch.Tensor:
        """val_cond: [B, val_dim] -> condition embedding [B, output_dim]."""
        val_hidden = self.val_fc(val_cond)
        return self.fusion(val_hidden)
class GinkaMapPatch(nn.Module):
    """Encode a 5x5 local window of the partially generated map.

    The window covers the 2 columns either side of ``x`` and the 4 rows above
    ``y`` plus the current row (zero-padded at map borders), is one-hot
    encoded, and run through a small CNN down to a 512-dim vector.

    Args:
        tile_classes: number of distinct tile ids (one-hot channel count).
        width, height: map dimensions used for border clamping.
    """

    def __init__(self, tile_classes=32, width=13, height=13):
        super().__init__()

        self.width = width
        self.height = height
        # Kept so forward() one-hots with the same channel count the CNN
        # expects (was hard-coded to 32, which broke any other tile_classes).
        self.tile_classes = tile_classes

        # Local CNN over the one-hot patch; AvgPool over the full 5x5 window
        # collapses it to a single 512-dim vector per sample.
        self.patch_cnn = nn.Sequential(
            nn.Conv2d(tile_classes, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.AvgPool2d(kernel_size=(5, 5)),
            nn.Flatten()
        )

    def forward(self, map: torch.Tensor, x: int, y: int):
        """
        map: [B, H, W] integer tile ids. Returns [B, 512].
        """
        B, H, W = map.shape
        # 5x5 buffer; cells outside the map stay zero (tile id 0 padding).
        result = torch.zeros([B, 5, 5], dtype=torch.long, device=map.device)
        # Source window clamped to the map bounds.
        left = x - 2 if x >= 2 else 0
        right = x + 3 if x < self.width - 2 else self.width
        top = y - 4 if y >= 4 else 0
        bottom = y + 1

        # Destination offsets of the valid region inside the 5x5 buffer.
        res_left = left - (x - 2)
        res_right = right - (x + 3) + 5
        res_top = top - (y - 4)
        res_bottom = 5

        result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right]
        # One-hot to channel dimension. NOTE(review): permute(0, 3, 2, 1)
        # also transposes the two spatial axes — harmless only because the
        # patch is square; confirm this is intentional.
        result = F.one_hot(result, num_classes=self.tile_classes).permute(0, 3, 2, 1).float()

        return self.patch_cnn(result)
class GinkaTileEmbedding(nn.Module):
    """Embed the previously drawn tile id into a dense vector."""

    def __init__(self, tile_classes=32, embed_dim=128):
        super().__init__()
        # Learned lookup table: tile id -> embed_dim vector.
        self.embedding = nn.Embedding(tile_classes, embed_dim)

    def forward(self, tile: torch.Tensor) -> torch.Tensor:
        """tile: integer tensor of tile ids; returns [..., embed_dim]."""
        return self.embedding(tile)
class GinkaPosEmbedding(nn.Module):
    """Learned positional embeddings for the (x, y) cell being generated."""

    def __init__(self, width=13, height=13, embed_dim=128):
        super().__init__()
        self.width = width
        self.height = height
        # One learned table per axis: x indexes the width table, y the
        # height table.
        self.row_embedding = nn.Embedding(width, embed_dim)
        self.col_embedding = nn.Embedding(height, embed_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """x, y: integer index tensors; returns (row_embed, col_embed)."""
        return self.row_embedding(x), self.col_embedding(y)
class GinkaInputFusion(nn.Module):
    """Fuse the per-step embeddings into one vector via a small Transformer."""

    def __init__(self, d_model=128):
        super().__init__()
        # The concatenated inputs are treated as a sequence of d_model-sized
        # tokens; self-attention mixes the information sources.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True
            ),
            num_layers=4
        )

    def forward(
        self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
        row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
    ):
        """
        tile_embed: [B, 128]; cond_vec: [B, 256]; row_embed: [B, 128];
        col_embed: [B, 128]; patch_vec: [B, 512]. Returns [B, 128].
        """
        # 128 + 256 + 128 + 128 + 512 = 9 * 128, so the split is exact.
        flat = torch.cat([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
        tokens = torch.stack(flat.split(128, dim=1), dim=1)  # [B, 9, 128]
        encoded = self.transformer(tokens)
        return encoded.mean(dim=1)
class GinkaRNN(nn.Module):
    """Single-step GRU cell producing per-tile logits.

    Args:
        tile_classes: size of the output logit vector.
        input_dim: dimensionality of the fused step feature.
        hidden_dim: GRU hidden state size.
    """

    def __init__(self, tile_classes=32, input_dim=128, hidden_dim=1024):
        super().__init__()

        # One GRU step per generated tile; the caller carries the hidden
        # state across steps.
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, tile_classes)

    def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor):
        """
        feat_fusion: [B, input_dim] fused step feature.
        hidden: [B, hidden_dim] previous hidden state.
        Returns (logits [B, tile_classes], new hidden [B, hidden_dim]).
        """
        hidden = self.gru(feat_fusion, hidden)
        logits = self.fc(hidden)
        return logits, hidden
class GinkaRNNModel(nn.Module):
    """Autoregressive map generator: one tile per step, in row-major order."""

    def __init__(self, device: torch.device, start_tile=31, width=13, height=13):
        super().__init__()

        self.device = device
        self.width = width
        self.height = height
        self.start_tile = start_tile

        self.rnn_hidden = 1024
        self.tile_classes = 32

        # Sub-modules. NOTE(review): width/height/tile_classes are not
        # forwarded to them, so they run with their 13x13 / 32-class
        # defaults — confirm before using other sizes.
        self.cond = RNNConditionEncoder()
        self.tile_embedding = GinkaTileEmbedding()
        self.pos_embedding = GinkaPosEmbedding()
        self.map_patch = GinkaMapPatch()
        self.feat_fusion = GinkaInputFusion()
        self.rnn = GinkaRNN(hidden_dim=self.rnn_hidden)

    def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False):
        """
        val_cond: [B, val_dim] numeric condition vector.
        target_map: [B, H, W] ground-truth map (teacher-forcing input).
        use_self: feed back the model's own previous prediction instead of
            the ground-truth previous tile.
        Returns (logits [B, H, W, tile_classes], generated map [B, H, W]).
        """
        B, C = val_cond.shape

        # Previous tile starts as the dedicated start token for every sample.
        # (Fixed: the legacy LongTensor(..., device=...) constructor, and a
        # non-writable [B, 1] expanded view whose shape did not match the
        # [B] teacher input target_map[:, y, x].)
        now_tile = torch.full((B,), self.start_tile, dtype=torch.long, device=self.device)

        # `gen_map` avoids shadowing the builtin `map`.
        gen_map = torch.zeros([B, self.height, self.width], dtype=torch.int32, device=self.device)
        output_logits = torch.zeros([B, self.height, self.width, self.tile_classes], device=self.device)
        hidden = torch.zeros(B, self.rnn_hidden, device=self.device)

        # The condition encoding is global, so compute it once.
        cond = self.cond(val_cond)

        for y in range(self.height):
            for x in range(self.width):
                # Per-sample position indices so the positional embeddings
                # broadcast over the batch (fixed: single-element tensors
                # broke the feature-fusion concat for B > 1).
                x_tensor = torch.full((B,), x, dtype=torch.long, device=self.device)
                y_tensor = torch.full((B,), y, dtype=torch.long, device=self.device)
                # Tile, positional, and local-map encodings.
                tile_embed = self.tile_embedding(now_tile if use_self else target_map[:, y, x])
                row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
                map_patch = self.map_patch(gen_map, x, y)
                # Fuse all encodings into one step feature.
                feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch)
                # One RNN step. (Fixed: `hidden[:] = h[:]` copied in place,
                # which clobbers the autograd graph of earlier steps.)
                logits, hidden = self.rnn(feat, hidden)
                # Record outputs; argmax over logits equals argmax over
                # softmax probabilities, so no softmax is needed here.
                output_logits[:, y, x] = logits
                tile_id = torch.argmax(logits, dim=1)
                gen_map[:, y, x] = tile_id
                now_tile = tile_id

        return output_logits, gen_map
def print_memory(device, tag=""):
    """Print current and peak CUDA memory usage for *device*.

    Falls back to a notice when CUDA is unavailable, so the script also
    runs on CPU-only machines.
    """
    if torch.cuda.is_available():
        print(f"{tag} | 当前显存: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated(device) / 1024**2:.2f} MB")
    else:
        print("当前设备不支持 cuda.")
if __name__ == "__main__":
    # Smoke test: build the model on CPU and time one forward pass.
    device = torch.device("cpu")

    # Random inputs: a 13x13 map of tile ids and a 16-dim condition vector.
    sample_map = torch.randint(0, 32, [1, 13, 13]).to(device)
    cond = torch.rand(1, 16).to(device)

    # Initialize the model (fixed: pass the torch.device object instead of
    # the raw string "cpu").
    model = GinkaRNNModel(device).to(device)

    # (Fixed: print_memory takes the device as its first argument; the old
    # call passed the tag where the device belongs.)
    print_memory(device, "初始化后")

    # Timed forward pass.
    start = time.perf_counter()
    fake_logits, fake_map = model(cond, sample_map, False)
    end = time.perf_counter()

    print_memory(device, "前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输入形状: feat={sample_map.shape}")
    print(f"输出形状: fake_logits={fake_logits.shape}, fake_map={fake_map.shape}")
    print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond.parameters())}")
    print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
    print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}")
    print(f"Map Patch parameters: {sum(p.numel() for p in model.map_patch.parameters())}")
    print(f"Feature Fusion parameters: {sum(p.numel() for p in model.feat_fusion.parameters())}")
    print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
Loading…
Reference in New Issue
Block a user