chore: 增大 RNN 参数量

This commit is contained in:
unanmed 2025-12-13 21:34:59 +08:00
parent 4d1179bcfb
commit 7ae4ae4572

View File

@@ -8,14 +8,13 @@ class RNNConditionEncoder(nn.Module):
super().__init__()
# 条件编码
self.val_fc = nn.Sequential(
nn.Linear(val_dim, output_dim),
nn.LayerNorm(output_dim),
nn.Linear(val_dim, output_dim * 2),
nn.LayerNorm(output_dim * 2),
nn.ReLU(),
)
self.fusion = nn.Sequential(
nn.Linear(output_dim, output_dim)
nn.Linear(output_dim * 2, output_dim)
)
def forward(self, val_cond: torch.Tensor):
@@ -70,7 +69,7 @@ class GinkaMapPatch(nn.Module):
return self.patch_cnn(result)
class GinkaTileEmbedding(nn.Module):
def __init__(self, tile_classes=32, embed_dim=128):
def __init__(self, tile_classes=32, embed_dim=256):
super().__init__()
# 图块编码,上一次画的图块
@@ -81,7 +80,7 @@ class GinkaTileEmbedding(nn.Module):
return self.embedding(tile)
class GinkaPosEmbedding(nn.Module):
def __init__(self, width=13, height=13, embed_dim=128):
def __init__(self, width=13, height=13, embed_dim=256):
super().__init__()
# 位置编码
@@ -99,7 +98,7 @@ class GinkaPosEmbedding(nn.Module):
return row, col
class GinkaInputFusion(nn.Module):
def __init__(self, d_model=128):
def __init__(self, d_model=256):
super().__init__()
# 使用 Transformer 进行信息整合
@@ -116,19 +115,19 @@ class GinkaInputFusion(nn.Module):
row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
):
"""
tile_embed: [B, 128]
tile_embed: [B, 256]
cond_vec: [B, 256]
row_embed: [B, 128]
col_embed: [B, 128]
row_embed: [B, 256]
col_embed: [B, 256]
patch_vec: [B, 512]
"""
vec = torch.cat([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
vec = torch.stack(torch.split(vec, 128, dim=1), dim=1)
vec = torch.stack(torch.split(vec, 256, dim=1), dim=1)
feat = self.transformer(vec)
return torch.mean(feat, dim=1)
class GinkaRNN(nn.Module):
def __init__(self, tile_classes=32, input_dim=128, hidden_dim=1024):
def __init__(self, tile_classes=32, input_dim=256, hidden_dim=2048):
super().__init__()
# GRU
@@ -153,16 +152,16 @@ class GinkaRNNModel(nn.Module):
self.height = height
self.start_tile = start_tile
self.rnn_hidden = 1024
self.rnn_hidden = 2048
self.tile_classes = 32
# 模型结构
self.cond = RNNConditionEncoder()
self.tile_embedding = GinkaTileEmbedding()
self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes)
self.pos_embedding = GinkaPosEmbedding()
self.map_patch = GinkaMapPatch()
self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes)
self.feat_fusion = GinkaInputFusion()
self.rnn = GinkaRNN(hidden_dim=self.rnn_hidden)
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False):
"""