diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index c899f1e..7499545 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -35,13 +35,13 @@ class GinkaMapPatch(nn.Module): nn.BatchNorm2d(256), nn.ReLU(), - nn.Conv2d(256, 512, 3, padding=1), + nn.Conv2d(256, 512, 3), nn.BatchNorm2d(512), nn.ReLU(), - nn.AvgPool2d(kernel_size=(5, 5)), nn.Flatten() ) + self.fc = nn.Linear(512 * 3 * 3, 256) def forward(self, map: torch.Tensor, x: int, y: int): """ @@ -66,7 +66,9 @@ class GinkaMapPatch(nn.Module): result[:, 4, 4] = 0 result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float() - return self.patch_cnn(result) + feat = self.patch_cnn(result) + feat = self.fc(feat) + return feat class GinkaTileEmbedding(nn.Module): def __init__(self, tile_classes=32, embed_dim=256): @@ -119,12 +121,11 @@ class GinkaInputFusion(nn.Module): cond_vec: [B, 256] row_embed: [B, 256] col_embed: [B, 256] - patch_vec: [B, 512] + patch_vec: [B, 256] """ - vec = torch.cat([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1) - vec = torch.stack(torch.split(vec, 256, dim=1), dim=1) + vec = torch.stack([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1) feat = self.transformer(vec) - return torch.mean(feat, dim=1) + return feat[:, 0] class GinkaRNN(nn.Module): def __init__(self, tile_classes=32, input_dim=256, hidden_dim=2048): @@ -141,7 +142,7 @@ class GinkaRNN(nn.Module): """ hidden = self.gru(feat_fusion, hidden) logits = self.fc(hidden) - return F.sigmoid(logits), hidden + return logits, hidden class GinkaRNNModel(nn.Module): def __init__(self, device: torch.device, start_tile=31, width=13, height=13): @@ -152,7 +153,7 @@ class GinkaRNNModel(nn.Module): self.height = height self.start_tile = start_tile - self.rnn_hidden = 2048 + self.rnn_hidden = 512 self.tile_classes = 32 # 模型结构 @@ -196,8 +197,7 @@ class GinkaRNNModel(nn.Module): # 处理输出 output_logits[:, y, x] = logits[:] hidden = h - probs = F.softmax(logits, dim=1) - tile_id = torch.argmax(probs, dim=1).detach() + tile_id = torch.argmax(logits, dim=1).detach() map[:, y, x] = tile_id[:] now_tile = tile_id if use_self else target_map[:, y, x].detach()