"""Autoregressive tile-map generator: fused per-step encodings fed to a GRU cell.

At every map position the model fuses five encodings (previous tile, global
condition, x/y position, local 5x5 map patch) with a small Transformer
encoder and feeds the fused vector to a GRUCell that predicts the tile.
"""
import time

import torch
import torch.nn as nn
import torch.nn.functional as F


class RNNConditionEncoder(nn.Module):
    """Encode the global condition vector into a conditioning feature.

    ``width`` and ``height`` are accepted for interface compatibility but are
    not used by this module.
    """

    def __init__(self, val_dim=16, output_dim=256, width=13, height=13):
        super().__init__()

        # Condition encoding
        self.val_fc = nn.Sequential(
            nn.Linear(val_dim, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
        )
        self.fusion = nn.Sequential(
            nn.Linear(output_dim, output_dim)
        )

    def forward(self, val_cond: torch.Tensor):
        """Map ``val_cond`` [B, val_dim] to a feature of size [B, output_dim]."""
        val_hidden = self.val_fc(val_cond)
        return self.fusion(val_hidden)


class GinkaMapPatch(nn.Module):
    """Local convolution over a 5x5 window of the partially generated map."""

    def __init__(self, tile_classes=32, width=13, height=13):
        super().__init__()

        self.tile_classes = tile_classes
        self.width = width
        self.height = height

        self.patch_cnn = nn.Sequential(
            nn.Conv2d(tile_classes, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.AvgPool2d(kernel_size=(5, 5)),
            nn.Flatten()
        )

    def extract_patch(self, tile_map: torch.Tensor, x: int, y: int):
        """Return the [B, 5, 5] long window ending at row ``y``, centred on column ``x``.

        The window covers rows ``y-4..y`` and columns ``x-2..x+2``. Cells that
        fall outside the map keep the fill value 0.
        Assumes ``0 <= x < W`` and ``0 <= y < H``.
        """
        B, _, W = tile_map.shape
        window = torch.zeros([B, 5, 5], dtype=torch.long, device=tile_map.device)

        # Clamp the source rectangle to the map (use the real map width
        # rather than the configured one so the two cannot disagree).
        left = max(x - 2, 0)
        right = min(x + 3, W)
        top = max(y - 4, 0)
        bottom = y + 1

        # Offsets of the clamped rectangle inside the 5x5 window.
        win_left = left - (x - 2)
        win_right = right - (x + 3) + 5
        win_top = top - (y - 4)

        window[:, win_top:5, win_left:win_right] = tile_map[:, top:bottom, left:right]
        return window

    def forward(self, map: torch.Tensor, x: int, y: int):
        """Encode the local window around ``(x, y)``.

        map: [B, H, W] integer tile ids
        Returns a [B, 512] feature (512 channels pooled over the 5x5 window).
        """
        window = self.extract_patch(map, x, y)
        one_hot = F.one_hot(window, num_classes=self.tile_classes)
        # NOTE: permute(0, 3, 2, 1) produces [B, C, W, H], i.e. the window is
        # transposed relative to the usual [B, C, H, W]. Kept as-is so that
        # weights trained with this layout stay valid.
        return self.patch_cnn(one_hot.permute(0, 3, 2, 1).float())


class GinkaTileEmbedding(nn.Module):
    """Embedding of the previously drawn tile."""

    def __init__(self, tile_classes=32, embed_dim=128):
        super().__init__()

        self.embedding = nn.Embedding(tile_classes, embed_dim)

    def forward(self, tile: torch.Tensor):
        """Look up the embedding for integer tile ids."""
        return self.embedding(tile)


class GinkaPosEmbedding(nn.Module):
    """Learned position embeddings for the x and y coordinates."""

    def __init__(self, width=13, height=13, embed_dim=128):
        super().__init__()

        self.width = width
        self.height = height

        # row_embedding is indexed by x (table size = width) and
        # col_embedding by y (table size = height).
        self.row_embedding = nn.Embedding(width, embed_dim)
        self.col_embedding = nn.Embedding(height, embed_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return ``(row_embedding(x), col_embedding(y))``."""
        row = self.row_embedding(x)
        col = self.col_embedding(y)

        return row, col


class GinkaInputFusion(nn.Module):
    """Fuse the per-step encodings with a small Transformer encoder."""

    def __init__(self, d_model=128):
        super().__init__()

        self.d_model = d_model

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True
            ),
            num_layers=4
        )

    def forward(
        self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
        row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
    ):
        """Concatenate the inputs, cut them into ``d_model`` tokens and fuse.

        With the default sizes:
        tile_embed: [B, 128]
        cond_vec: [B, 256]
        row_embed: [B, 128]
        col_embed: [B, 128]
        patch_vec: [B, 512]

        Inputs whose batch size is 1 are broadcast to the common batch size.
        The concatenated width must be a multiple of ``d_model``.
        Returns the mean over tokens, [B, d_model].
        """
        parts = [tile_embed, cond_vec, row_embed, col_embed, patch_vec]
        batch = max(part.shape[0] for part in parts)
        parts = [
            part if part.shape[0] == batch else part.expand(batch, -1)
            for part in parts
        ]
        vec = torch.cat(parts, dim=1)
        tokens = vec.reshape(batch, -1, self.d_model)
        feat = self.transformer(tokens)
        return torch.mean(feat, dim=1)


class GinkaRNN(nn.Module):
    """Single GRU step followed by a linear classifier over tile classes."""

    def __init__(self, tile_classes=32, input_dim=128, hidden_dim=1024):
        super().__init__()

        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, tile_classes)

    def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor):
        """Advance the GRU by one step.

        feat_fusion: [B, input_dim]
        hidden: [B, hidden_dim]
        Returns ``(logits [B, tile_classes], new_hidden [B, hidden_dim])``.
        """
        hidden = self.gru(feat_fusion, hidden)
        logits = self.fc(hidden)
        return logits, hidden


class GinkaRNNModel(nn.Module):
    """Generate a tile map position by position in row-major order."""

    def __init__(self, device: torch.device, start_tile=31, width=13, height=13):
        super().__init__()

        # Kept for callers that read it; forward() takes the device from its
        # input so the model keeps working after ``.to(...)``.
        self.device = torch.device(device)
        self.width = width
        self.height = height
        self.start_tile = start_tile

        self.rnn_hidden = 1024
        self.tile_classes = 32

        # Sub-modules (map size and class count are forwarded so that a
        # non-default width/height is honoured everywhere).
        self.cond = RNNConditionEncoder(width=width, height=height)
        self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes)
        self.pos_embedding = GinkaPosEmbedding(width=width, height=height)
        self.map_patch = GinkaMapPatch(
            tile_classes=self.tile_classes, width=width, height=height
        )
        self.feat_fusion = GinkaInputFusion()
        self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)

    def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False):
        """Run the generator over the whole map.

        val_cond: [B, val_dim]
        target_map: [B, H, W] ground-truth tile ids (read only when
            ``use_self`` is False)
        use_self: if True, feed the model's own previous prediction as the
            previous tile; otherwise feed the ground-truth tile of the
            previous position (teacher forcing).

        Returns ``(logits [B, H, W, tile_classes], map [B, H, W] int32)``
        where ``map`` holds the argmax prediction for each position.
        """
        B = val_cond.shape[0]
        device = val_cond.device

        # The very first step always sees the start tile.
        prev_tile = torch.full((B,), self.start_tile, dtype=torch.long, device=device)
        generated = torch.zeros(
            [B, self.height, self.width], dtype=torch.int32, device=device
        )
        step_logits = []

        # The condition is global, so it is encoded only once.
        cond = self.cond(val_cond)
        hidden = cond.new_zeros(B, self.rnn_hidden)

        for y in range(self.height):
            y_idx = torch.full((B,), y, dtype=torch.long, device=device)
            for x in range(self.width):
                x_idx = torch.full((B,), x, dtype=torch.long, device=device)
                # Position, previous-tile and local-patch encodings
                tile_embed = self.tile_embedding(prev_tile)
                row_embed, col_embed = self.pos_embedding(x_idx, y_idx)
                patch_vec = self.map_patch(generated, x, y)
                # Fuse the encodings
                feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, patch_vec)
                # Rebind (do not overwrite in place): the old hidden state is
                # still needed by autograd for the backward pass.
                logits, hidden = self.rnn(feat, hidden)
                step_logits.append(logits)

                tile_id = torch.argmax(logits, dim=1)
                generated[:, y, x] = tile_id
                # The tile at (y, x) becomes the "previous tile" of the next
                # step; the current target is never shown to the current step.
                prev_tile = tile_id if use_self else target_map[:, y, x].long()

        output_logits = torch.stack(step_logits, dim=1).view(
            B, self.height, self.width, self.tile_classes
        )
        return output_logits, generated


def print_memory(device, tag=""):
    """Print current and peak CUDA memory for ``device``, prefixed by ``tag``."""
    if torch.cuda.is_available() and torch.device(device).type == "cuda":
        print(f"{tag} | 当前显存: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated(device) / 1024**2:.2f} MB")
    else:
        print("当前设备不支持 cuda.")


def main():
    """Smoke test: one forward pass on random input plus parameter counts."""
    device = torch.device("cpu")

    target_map = torch.randint(0, 32, [1, 13, 13]).to(device)
    cond = torch.rand(1, 16).to(device)

    # Build the model
    model = GinkaRNNModel(device).to(device)

    print_memory(device, "初始化后")

    # Forward pass
    start = time.perf_counter()
    fake_logits, fake_map = model(cond, target_map, False)
    end = time.perf_counter()

    print_memory(device, "前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: fake_logits={fake_logits.shape}, fake_map={fake_map.shape}")
    print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond.parameters())}")
    print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
    print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}")
    print(f"Map Patch parameters: {sum(p.numel() for p in model.map_patch.parameters())}")
    print(f"Feature Fusion parameters: {sum(p.numel() for p in model.feat_fusion.parameters())}")
    print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")


if __name__ == "__main__":
    main()