chore: fine-tune the model

unanmed 2025-12-14 12:06:53 +08:00
parent adcaa55432
commit 616b7fd39c


@@ -35,13 +35,13 @@ class GinkaMapPatch(nn.Module):
             nn.BatchNorm2d(256),
             nn.ReLU(),
-            nn.Conv2d(256, 512, 3, padding=1),
+            nn.Conv2d(256, 512, 3),
             nn.BatchNorm2d(512),
             nn.ReLU(),
             nn.AvgPool2d(kernel_size=(5, 5)),
             nn.Flatten()
         )
+        self.fc = nn.Linear(512 * 3 * 3, 256)
 
     def forward(self, map: torch.Tensor, x: int, y: int):
         """
@@ -66,7 +66,9 @@ class GinkaMapPatch(nn.Module):
         result[:, 4, 4] = 0
         result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float()
-        return self.patch_cnn(result)
+        feat = self.patch_cnn(result)
+        feat = self.fc(feat)
+        return feat
 
 class GinkaTileEmbedding(nn.Module):
     def __init__(self, tile_classes=32, embed_dim=256):
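Note on the two GinkaMapPatch hunks above: dropping padding=1 shrinks the last conv's output by 2 per side, and the new self.fc projects the flattened 512 * 3 * 3 pooled feature down to 256, so the patch vector now matches the other 256-dim inputs to the fusion block (see the GinkaInputFusion hunk below). A minimal shape check of the new tail, assuming the earlier (unshown) layers hand this conv a 17x17 map:

    import torch
    import torch.nn as nn

    # Tail of patch_cnn after this commit, plus the new projection.
    tail = nn.Sequential(
        nn.Conv2d(256, 512, 3),            # no padding: 17x17 -> 15x15
        nn.BatchNorm2d(512),
        nn.ReLU(),
        nn.AvgPool2d(kernel_size=(5, 5)),  # stride defaults to kernel: 15x15 -> 3x3
        nn.Flatten(),                      # -> [B, 512 * 3 * 3]
    )
    fc = nn.Linear(512 * 3 * 3, 256)

    x = torch.randn(2, 256, 17, 17)        # hypothetical input size
    print(fc(tail(x)).shape)               # torch.Size([2, 256])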
@@ -119,12 +121,11 @@ class GinkaInputFusion(nn.Module):
         cond_vec: [B, 256]
         row_embed: [B, 256]
         col_embed: [B, 256]
-        patch_vec: [B, 512]
+        patch_vec: [B, 256]
         """
-        vec = torch.cat([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
-        vec = torch.stack(torch.split(vec, 256, dim=1), dim=1)
+        vec = torch.stack([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
         feat = self.transformer(vec)
-        return torch.mean(feat, dim=1)
+        return feat[:, 0]
 
 class GinkaRNN(nn.Module):
     def __init__(self, tile_classes=32, input_dim=256, hidden_dim=2048):
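Two simplifications in this hunk: torch.stack on five equal-size [B, 256] vectors builds the same [B, 5, 256] token sequence as the old cat-split-stack chain in one step, and feat[:, 0] reads the transformer's first output token as the summary vector instead of mean-pooling all five. A quick equivalence check for the stacking change, with dummy tensors:

    import torch

    parts = [torch.randn(2, 256) for _ in range(5)]  # tile, cond, row, col, patch vectors

    old = torch.stack(torch.split(torch.cat(parts, dim=1), 256, dim=1), dim=1)
    new = torch.stack(parts, dim=1)                  # same [B, 5, 256] result

    print(torch.equal(old, new))                     # True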
@@ -141,7 +142,7 @@ class GinkaRNN(nn.Module):
         """
         hidden = self.gru(feat_fusion, hidden)
         logits = self.fc(hidden)
-        return F.sigmoid(logits), hidden
+        return logits, hidden
 
 class GinkaRNNModel(nn.Module):
     def __init__(self, device: torch.device, start_tile=31, width=13, height=13):
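Returning raw logits instead of F.sigmoid(logits) is the usual setup for a multi-class tile head: a loss such as nn.CrossEntropyLoss applies log-softmax internally, so squashing the scores through a sigmoid first would distort it. A sketch of the intended pairing (the loss choice is an assumption; the training code is not part of this commit):

    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()     # expects raw logits

    logits = torch.randn(2, 32)           # [B, tile_classes] scores from self.fc
    target = torch.randint(0, 32, (2,))   # ground-truth tile ids
    print(criterion(logits, target).item())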
@@ -152,7 +153,7 @@ class GinkaRNNModel(nn.Module):
         self.height = height
         self.start_tile = start_tile
-        self.rnn_hidden = 2048
+        self.rnn_hidden = 512
         self.tile_classes = 32
 
         # Model structure
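Shrinking rnn_hidden from 2048 to 512 cuts the recurrent parameter count by roughly 12x; a quick comparison, assuming the recurrence is an nn.GRUCell over the 256-dim fused input (consistent with input_dim=256 and the self.gru(feat_fusion, hidden) call above):

    import torch.nn as nn

    def n_params(m: nn.Module) -> int:
        return sum(p.numel() for p in m.parameters())

    big = nn.GRUCell(256, 2048)    # hidden size before this commit
    small = nn.GRUCell(256, 512)   # hidden size after

    print(n_params(big), n_params(small))  # ~14.2M vs ~1.2M weights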
@@ -196,8 +197,7 @@ class GinkaRNNModel(nn.Module):
                 # Process the output
                 output_logits[:, y, x] = logits[:]
                 hidden = h
-                probs = F.softmax(logits, dim=1)
-                tile_id = torch.argmax(probs, dim=1).detach()
+                tile_id = torch.argmax(logits, dim=1).detach()
                 map[:, y, x] = tile_id[:]
                 now_tile = tile_id if use_self else target_map[:, y, x].detach()
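Dropping the softmax before argmax is safe because softmax is strictly increasing per row, so the highest logit and the highest probability always sit at the same index; the intermediate probs tensor was pure overhead during generation. A one-line check:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 32)  # [B, tile_classes]
    print(torch.equal(
        torch.argmax(logits, dim=1),
        torch.argmax(F.softmax(logits, dim=1), dim=1),
    ))  # True: softmax never changes the winning index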