mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 添加地图蒙版
This commit is contained in:
parent
1962c7a712
commit
bf5160edac
@ -9,12 +9,12 @@ class RNNConditionEncoder(nn.Module):
|
||||
|
||||
# 条件编码
|
||||
self.val_fc = nn.Sequential(
|
||||
nn.Linear(val_dim, output_dim * 2),
|
||||
nn.LayerNorm(output_dim * 2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(val_dim, output_dim * 4),
|
||||
nn.LayerNorm(output_dim * 4),
|
||||
nn.GELU(),
|
||||
)
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Linear(output_dim * 2, output_dim)
|
||||
nn.Linear(output_dim * 4, output_dim)
|
||||
)
|
||||
|
||||
def forward(self, val_cond: torch.Tensor):
|
||||
@ -29,17 +29,18 @@ class GinkaMapPatch(nn.Module):
|
||||
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.tile_classes = 32
|
||||
|
||||
self.patch_cnn = nn.Sequential(
|
||||
nn.Conv2d(tile_classes, 256, 3, padding=1),
|
||||
nn.Conv2d(tile_classes + 1, 256, 3, padding=1),
|
||||
nn.Dropout(0.2),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Conv2d(256, 512, 3),
|
||||
nn.Dropout(0.2),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Flatten()
|
||||
)
|
||||
@ -50,6 +51,7 @@ class GinkaMapPatch(nn.Module):
|
||||
map: [B, H, W]
|
||||
"""
|
||||
B, H, W = map.shape
|
||||
mask = torch.zeros([B, 5, 5]).to(map.device)
|
||||
result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device)
|
||||
left = x - 2 if x >= 2 else 0
|
||||
right = x + 3 if x < self.width - 2 else self.width
|
||||
@ -66,9 +68,15 @@ class GinkaMapPatch(nn.Module):
|
||||
result[:, 4, 2] = 0
|
||||
result[:, 4, 3] = 0
|
||||
result[:, 4, 4] = 0
|
||||
result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float()
|
||||
mask[:, res_top:res_bottom, res_left:res_right] = 1
|
||||
mask[:, 4, 2] = 0
|
||||
mask[:, 4, 3] = 0
|
||||
mask[:, 4, 4] = 0
|
||||
masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5])
|
||||
masked_result[:, 0:32] = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float()
|
||||
masked_result[:, 32] = mask
|
||||
|
||||
feat = self.patch_cnn(result)
|
||||
feat = self.patch_cnn(masked_result)
|
||||
feat = self.fc(feat)
|
||||
return feat
|
||||
|
||||
@ -137,7 +145,13 @@ class GinkaRNN(nn.Module):
|
||||
# GRU
|
||||
self.gru = nn.GRUCell(input_dim, hidden_dim)
|
||||
self.drop = nn.Dropout(0.2)
|
||||
self.fc = nn.Linear(hidden_dim, tile_classes)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Linear(hidden_dim, tile_classes)
|
||||
)
|
||||
|
||||
def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor):
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user