From bf5160edacb962a4176ce3b1b44739319e9a53a1 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 17 Dec 2025 13:09:10 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=9C=B0=E5=9B=BE?= =?UTF-8?q?=E8=92=99=E7=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/generator/rnn.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index 6c5fd1e..654cae2 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -9,12 +9,12 @@ class RNNConditionEncoder(nn.Module): # 条件编码 self.val_fc = nn.Sequential( - nn.Linear(val_dim, output_dim * 2), - nn.LayerNorm(output_dim * 2), - nn.ReLU(), + nn.Linear(val_dim, output_dim * 4), + nn.LayerNorm(output_dim * 4), + nn.GELU(), ) self.fusion = nn.Sequential( - nn.Linear(output_dim * 2, output_dim) + nn.Linear(output_dim * 4, output_dim) ) def forward(self, val_cond: torch.Tensor): @@ -29,17 +29,18 @@ class GinkaMapPatch(nn.Module): self.width = width self.height = height + self.tile_classes = 32 self.patch_cnn = nn.Sequential( - nn.Conv2d(tile_classes, 256, 3, padding=1), + nn.Conv2d(tile_classes + 1, 256, 3, padding=1), nn.Dropout(0.2), nn.BatchNorm2d(256), - nn.ReLU(), + nn.GELU(), nn.Conv2d(256, 512, 3), nn.Dropout(0.2), nn.BatchNorm2d(512), - nn.ReLU(), + nn.GELU(), nn.Flatten() ) @@ -50,6 +51,7 @@ class GinkaMapPatch(nn.Module): map: [B, H, W] """ B, H, W = map.shape + mask = torch.zeros([B, 5, 5]).to(map.device) result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device) left = x - 2 if x >= 2 else 0 right = x + 3 if x < self.width - 2 else self.width @@ -66,9 +68,15 @@ class GinkaMapPatch(nn.Module): result[:, 4, 2] = 0 result[:, 4, 3] = 0 result[:, 4, 4] = 0 - result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float() + mask[:, res_top:res_bottom, res_left:res_right] = 1 + mask[:, 4, 2] = 0 + mask[:, 4, 3] = 0 + mask[:, 4, 4] = 0 + masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5]) + masked_result[:, 0:32] = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float() + masked_result[:, 32] = mask - feat = self.patch_cnn(result) + feat = self.patch_cnn(masked_result) feat = self.fc(feat) return feat @@ -137,7 +145,13 @@ class GinkaRNN(nn.Module): # GRU self.gru = nn.GRUCell(input_dim, hidden_dim) self.drop = nn.Dropout(0.2) - self.fc = nn.Linear(hidden_dim, tile_classes) + self.fc = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.GELU(), + + nn.Linear(hidden_dim, tile_classes) + ) def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor): """