mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 微调模型
This commit is contained in:
parent
adcaa55432
commit
616b7fd39c
@ -35,13 +35,13 @@ class GinkaMapPatch(nn.Module):
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
|
||||
nn.Conv2d(256, 512, 3, padding=1),
|
||||
nn.Conv2d(256, 512, 3),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
|
||||
nn.AvgPool2d(kernel_size=(5, 5)),
|
||||
nn.Flatten()
|
||||
)
|
||||
self.fc = nn.Linear(512 * 3 * 3, 256)
|
||||
|
||||
def forward(self, map: torch.Tensor, x: int, y: int):
|
||||
"""
|
||||
@ -66,7 +66,9 @@ class GinkaMapPatch(nn.Module):
|
||||
result[:, 4, 4] = 0
|
||||
result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float()
|
||||
|
||||
return self.patch_cnn(result)
|
||||
feat = self.patch_cnn(result)
|
||||
feat = self.fc(feat)
|
||||
return feat
|
||||
|
||||
class GinkaTileEmbedding(nn.Module):
|
||||
def __init__(self, tile_classes=32, embed_dim=256):
|
||||
@ -119,12 +121,11 @@ class GinkaInputFusion(nn.Module):
|
||||
cond_vec: [B, 256]
|
||||
row_embed: [B, 256]
|
||||
col_embed: [B, 256]
|
||||
patch_vec: [B, 512]
|
||||
patch_vec: [B, 256]
|
||||
"""
|
||||
vec = torch.cat([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
|
||||
vec = torch.stack(torch.split(vec, 256, dim=1), dim=1)
|
||||
vec = torch.stack([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
|
||||
feat = self.transformer(vec)
|
||||
return torch.mean(feat, dim=1)
|
||||
return feat[:, 0]
|
||||
|
||||
class GinkaRNN(nn.Module):
|
||||
def __init__(self, tile_classes=32, input_dim=256, hidden_dim=2048):
|
||||
@ -141,7 +142,7 @@ class GinkaRNN(nn.Module):
|
||||
"""
|
||||
hidden = self.gru(feat_fusion, hidden)
|
||||
logits = self.fc(hidden)
|
||||
return F.sigmoid(logits), hidden
|
||||
return logits, hidden
|
||||
|
||||
class GinkaRNNModel(nn.Module):
|
||||
def __init__(self, device: torch.device, start_tile=31, width=13, height=13):
|
||||
@ -152,7 +153,7 @@ class GinkaRNNModel(nn.Module):
|
||||
self.height = height
|
||||
self.start_tile = start_tile
|
||||
|
||||
self.rnn_hidden = 2048
|
||||
self.rnn_hidden = 512
|
||||
self.tile_classes = 32
|
||||
|
||||
# 模型结构
|
||||
@ -196,8 +197,7 @@ class GinkaRNNModel(nn.Module):
|
||||
# 处理输出
|
||||
output_logits[:, y, x] = logits[:]
|
||||
hidden = h
|
||||
probs = F.softmax(logits, dim=1)
|
||||
tile_id = torch.argmax(probs, dim=1).detach()
|
||||
tile_id = torch.argmax(logits, dim=1).detach()
|
||||
map[:, y, x] = tile_id[:]
|
||||
now_tile = tile_id if use_self else target_map[:, y, x].detach()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user