mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 增大 RNN 参数量
This commit is contained in:
parent
4d1179bcfb
commit
7ae4ae4572
@ -8,14 +8,13 @@ class RNNConditionEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# 条件编码
|
# 条件编码
|
||||||
|
|
||||||
self.val_fc = nn.Sequential(
|
self.val_fc = nn.Sequential(
|
||||||
nn.Linear(val_dim, output_dim),
|
nn.Linear(val_dim, output_dim * 2),
|
||||||
nn.LayerNorm(output_dim),
|
nn.LayerNorm(output_dim * 2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
)
|
)
|
||||||
self.fusion = nn.Sequential(
|
self.fusion = nn.Sequential(
|
||||||
nn.Linear(output_dim, output_dim)
|
nn.Linear(output_dim * 2, output_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, val_cond: torch.Tensor):
|
def forward(self, val_cond: torch.Tensor):
|
||||||
@ -70,7 +69,7 @@ class GinkaMapPatch(nn.Module):
|
|||||||
return self.patch_cnn(result)
|
return self.patch_cnn(result)
|
||||||
|
|
||||||
class GinkaTileEmbedding(nn.Module):
|
class GinkaTileEmbedding(nn.Module):
|
||||||
def __init__(self, tile_classes=32, embed_dim=128):
|
def __init__(self, tile_classes=32, embed_dim=256):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# 图块编码,上一次画的图块
|
# 图块编码,上一次画的图块
|
||||||
@ -81,7 +80,7 @@ class GinkaTileEmbedding(nn.Module):
|
|||||||
return self.embedding(tile)
|
return self.embedding(tile)
|
||||||
|
|
||||||
class GinkaPosEmbedding(nn.Module):
|
class GinkaPosEmbedding(nn.Module):
|
||||||
def __init__(self, width=13, height=13, embed_dim=128):
|
def __init__(self, width=13, height=13, embed_dim=256):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# 位置编码
|
# 位置编码
|
||||||
@ -99,7 +98,7 @@ class GinkaPosEmbedding(nn.Module):
|
|||||||
return row, col
|
return row, col
|
||||||
|
|
||||||
class GinkaInputFusion(nn.Module):
|
class GinkaInputFusion(nn.Module):
|
||||||
def __init__(self, d_model=128):
|
def __init__(self, d_model=256):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# 使用 Transformer 进行信息整合
|
# 使用 Transformer 进行信息整合
|
||||||
@ -116,19 +115,19 @@ class GinkaInputFusion(nn.Module):
|
|||||||
row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
|
row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
tile_embed: [B, 128]
|
tile_embed: [B, 256]
|
||||||
cond_vec: [B, 256]
|
cond_vec: [B, 256]
|
||||||
row_embed: [B, 128]
|
row_embed: [B, 256]
|
||||||
col_embed: [B, 128]
|
col_embed: [B, 256]
|
||||||
patch_vec: [B, 512]
|
patch_vec: [B, 512]
|
||||||
"""
|
"""
|
||||||
vec = torch.cat([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
|
vec = torch.cat([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
|
||||||
vec = torch.stack(torch.split(vec, 128, dim=1), dim=1)
|
vec = torch.stack(torch.split(vec, 256, dim=1), dim=1)
|
||||||
feat = self.transformer(vec)
|
feat = self.transformer(vec)
|
||||||
return torch.mean(feat, dim=1)
|
return torch.mean(feat, dim=1)
|
||||||
|
|
||||||
class GinkaRNN(nn.Module):
|
class GinkaRNN(nn.Module):
|
||||||
def __init__(self, tile_classes=32, input_dim=128, hidden_dim=1024):
|
def __init__(self, tile_classes=32, input_dim=256, hidden_dim=2048):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# GRU
|
# GRU
|
||||||
@ -153,16 +152,16 @@ class GinkaRNNModel(nn.Module):
|
|||||||
self.height = height
|
self.height = height
|
||||||
self.start_tile = start_tile
|
self.start_tile = start_tile
|
||||||
|
|
||||||
self.rnn_hidden = 1024
|
self.rnn_hidden = 2048
|
||||||
self.tile_classes = 32
|
self.tile_classes = 32
|
||||||
|
|
||||||
# 模型结构
|
# 模型结构
|
||||||
self.cond = RNNConditionEncoder()
|
self.cond = RNNConditionEncoder()
|
||||||
self.tile_embedding = GinkaTileEmbedding()
|
self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes)
|
||||||
self.pos_embedding = GinkaPosEmbedding()
|
self.pos_embedding = GinkaPosEmbedding()
|
||||||
self.map_patch = GinkaMapPatch()
|
self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes)
|
||||||
self.feat_fusion = GinkaInputFusion()
|
self.feat_fusion = GinkaInputFusion()
|
||||||
self.rnn = GinkaRNN(hidden_dim=self.rnn_hidden)
|
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
|
||||||
|
|
||||||
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False):
|
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user