mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 增大 RNN 参数量
This commit is contained in:
parent
4d1179bcfb
commit
7ae4ae4572
@ -8,14 +8,13 @@ class RNNConditionEncoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
# 条件编码
|
||||
|
||||
self.val_fc = nn.Sequential(
|
||||
nn.Linear(val_dim, output_dim),
|
||||
nn.LayerNorm(output_dim),
|
||||
nn.Linear(val_dim, output_dim * 2),
|
||||
nn.LayerNorm(output_dim * 2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Linear(output_dim, output_dim)
|
||||
nn.Linear(output_dim * 2, output_dim)
|
||||
)
|
||||
|
||||
def forward(self, val_cond: torch.Tensor):
|
||||
@ -70,7 +69,7 @@ class GinkaMapPatch(nn.Module):
|
||||
return self.patch_cnn(result)
|
||||
|
||||
class GinkaTileEmbedding(nn.Module):
|
||||
def __init__(self, tile_classes=32, embed_dim=128):
|
||||
def __init__(self, tile_classes=32, embed_dim=256):
|
||||
super().__init__()
|
||||
|
||||
# 图块编码,上一次画的图块
|
||||
@ -81,7 +80,7 @@ class GinkaTileEmbedding(nn.Module):
|
||||
return self.embedding(tile)
|
||||
|
||||
class GinkaPosEmbedding(nn.Module):
|
||||
def __init__(self, width=13, height=13, embed_dim=128):
|
||||
def __init__(self, width=13, height=13, embed_dim=256):
|
||||
super().__init__()
|
||||
|
||||
# 位置编码
|
||||
@ -99,7 +98,7 @@ class GinkaPosEmbedding(nn.Module):
|
||||
return row, col
|
||||
|
||||
class GinkaInputFusion(nn.Module):
|
||||
def __init__(self, d_model=128):
|
||||
def __init__(self, d_model=256):
|
||||
super().__init__()
|
||||
|
||||
# 使用 Transformer 进行信息整合
|
||||
@ -116,19 +115,19 @@ class GinkaInputFusion(nn.Module):
|
||||
row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
|
||||
):
|
||||
"""
|
||||
tile_embed: [B, 128]
|
||||
tile_embed: [B, 256]
|
||||
cond_vec: [B, 256]
|
||||
row_embed: [B, 128]
|
||||
col_embed: [B, 128]
|
||||
row_embed: [B, 256]
|
||||
col_embed: [B, 256]
|
||||
patch_vec: [B, 512]
|
||||
"""
|
||||
vec = torch.cat([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
|
||||
vec = torch.stack(torch.split(vec, 128, dim=1), dim=1)
|
||||
vec = torch.stack(torch.split(vec, 256, dim=1), dim=1)
|
||||
feat = self.transformer(vec)
|
||||
return torch.mean(feat, dim=1)
|
||||
|
||||
class GinkaRNN(nn.Module):
|
||||
def __init__(self, tile_classes=32, input_dim=128, hidden_dim=1024):
|
||||
def __init__(self, tile_classes=32, input_dim=256, hidden_dim=2048):
|
||||
super().__init__()
|
||||
|
||||
# GRU
|
||||
@ -153,16 +152,16 @@ class GinkaRNNModel(nn.Module):
|
||||
self.height = height
|
||||
self.start_tile = start_tile
|
||||
|
||||
self.rnn_hidden = 1024
|
||||
self.rnn_hidden = 2048
|
||||
self.tile_classes = 32
|
||||
|
||||
# 模型结构
|
||||
self.cond = RNNConditionEncoder()
|
||||
self.tile_embedding = GinkaTileEmbedding()
|
||||
self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes)
|
||||
self.pos_embedding = GinkaPosEmbedding()
|
||||
self.map_patch = GinkaMapPatch()
|
||||
self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes)
|
||||
self.feat_fusion = GinkaInputFusion()
|
||||
self.rnn = GinkaRNN(hidden_dim=self.rnn_hidden)
|
||||
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
|
||||
|
||||
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False):
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user