From 7ae4ae457258950de0f7f4bd4e2cd4e3597b512f Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sat, 13 Dec 2025 21:34:59 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=A2=9E=E5=A4=A7=20RNN=20=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/generator/rnn.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index ef83f52..c899f1e 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -8,14 +8,13 @@ class RNNConditionEncoder(nn.Module): super().__init__() # 条件编码 - self.val_fc = nn.Sequential( - nn.Linear(val_dim, output_dim), - nn.LayerNorm(output_dim), + nn.Linear(val_dim, output_dim * 2), + nn.LayerNorm(output_dim * 2), nn.ReLU(), ) self.fusion = nn.Sequential( - nn.Linear(output_dim, output_dim) + nn.Linear(output_dim * 2, output_dim) ) def forward(self, val_cond: torch.Tensor): @@ -70,7 +69,7 @@ class GinkaMapPatch(nn.Module): return self.patch_cnn(result) class GinkaTileEmbedding(nn.Module): - def __init__(self, tile_classes=32, embed_dim=128): + def __init__(self, tile_classes=32, embed_dim=256): super().__init__() # 图块编码,上一次画的图块 @@ -81,7 +80,7 @@ class GinkaTileEmbedding(nn.Module): return self.embedding(tile) class GinkaPosEmbedding(nn.Module): - def __init__(self, width=13, height=13, embed_dim=128): + def __init__(self, width=13, height=13, embed_dim=256): super().__init__() # 位置编码 @@ -99,7 +98,7 @@ class GinkaPosEmbedding(nn.Module): return row, col class GinkaInputFusion(nn.Module): - def __init__(self, d_model=128): + def __init__(self, d_model=256): super().__init__() # 使用 Transformer 进行信息整合 @@ -116,19 +115,19 @@ class GinkaInputFusion(nn.Module): row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor ): """ - tile_embed: [B, 128] + tile_embed: [B, 256] cond_vec: [B, 256] - row_embed: [B, 128] - col_embed: [B, 128] + row_embed: [B, 256] + col_embed: [B, 256] patch_vec: [B, 512] """ vec = torch.cat([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1) - vec = torch.stack(torch.split(vec, 128, dim=1), dim=1) + vec = torch.stack(torch.split(vec, 256, dim=1), dim=1) feat = self.transformer(vec) return torch.mean(feat, dim=1) class GinkaRNN(nn.Module): - def __init__(self, tile_classes=32, input_dim=128, hidden_dim=1024): + def __init__(self, tile_classes=32, input_dim=256, hidden_dim=2048): super().__init__() # GRU @@ -153,16 +152,16 @@ class GinkaRNNModel(nn.Module): self.height = height self.start_tile = start_tile - self.rnn_hidden = 1024 + self.rnn_hidden = 2048 self.tile_classes = 32 # 模型结构 self.cond = RNNConditionEncoder() - self.tile_embedding = GinkaTileEmbedding() + self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes) self.pos_embedding = GinkaPosEmbedding() - self.map_patch = GinkaMapPatch() + self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes) self.feat_fusion = GinkaInputFusion() - self.rnn = GinkaRNN(hidden_dim=self.rnn_hidden) + self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden) def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False): """