perf: 优化网络参数

This commit is contained in:
unanmed 2025-03-16 20:43:37 +08:00
parent 98f7a9cdcf
commit 0910bddba2
3 changed files with 29 additions and 11 deletions

View File

@ -1,7 +1,7 @@
import torch.nn as nn
class MinamoLoss(nn.Module):
def __init__(self, vision_weight=0, topo_weight=1):
def __init__(self, vision_weight=0.4, topo_weight=0.6):
super().__init__()
self.vision_weight = vision_weight
self.topo_weight = topo_weight

View File

@ -45,7 +45,7 @@ class DirectionalAttention(nn.Module):
return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)
class MinamoModel(nn.Module):
def __init__(self, tile_types=32, embedding_dim=64, conv_channels=256):
def __init__(self, tile_types=32, embedding_dim=16, conv_channels=32):
super().__init__()
# 嵌入层处理不同图块类型
self.embedding = nn.Embedding(tile_types, embedding_dim)
@ -57,25 +57,43 @@ class MinamoModel(nn.Module):
nn.ReLU(),
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
DualAttention(conv_channels*2),
nn.BatchNorm2d(conv_channels*2),
nn.ReLU(),
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
DualAttention(conv_channels*4),
nn.BatchNorm2d(conv_channels*4),
nn.ReLU(),
nn.Conv2d(conv_channels*4, conv_channels*8, 3, padding=1),
DualAttention(conv_channels*8),
nn.BatchNorm2d(conv_channels*8),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1)
)
# 拓扑特征分支
self.topo_conv = nn.Sequential(
nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构
nn.MaxPool2d(2),
nn.BatchNorm2d(conv_channels),
nn.ReLU(),
nn.Conv2d(conv_channels, conv_channels*2, 5, padding=2), # 更大卷积核捕捉结构
nn.BatchNorm2d(conv_channels*2),
nn.ReLU(),
nn.Conv2d(conv_channels*2, conv_channels*4, 5, padding=2), # 更大卷积核捕捉结构
nn.BatchNorm2d(conv_channels*4),
nn.ReLU(),
# nn.MaxPool2d(2),
# GraphConvLayer(128, 256), # 图卷积层
nn.AdaptiveMaxPool2d(1)
)
# 多任务预测头
self.vision_head = nn.Sequential(
nn.Linear(conv_channels*2, 1),
nn.Linear(conv_channels*8, 1),
nn.Sigmoid()
)
self.topo_head = nn.Sequential(
nn.Linear(conv_channels, 1),
nn.Linear(conv_channels*4, 1),
nn.Sigmoid()
)

View File

@ -33,21 +33,21 @@ def collate_fn(batch):
)
def train():
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.")
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
model = MinamoModel(32)
model.to(device)
# 准备数据集
dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-dataset.json")
val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json")
dataset = MinamoDataset("minamo-dataset.json")
val_dataset = MinamoDataset("minamo-eval.json")
dataloader = DataLoader(
dataset,
batch_size=32,
batch_size=64,
shuffle=True
)
val_loader = DataLoader(
val_dataset,
batch_size=32,
batch_size=64,
shuffle=True
)
@ -98,7 +98,7 @@ def train():
scheduler.step()
# 每十轮推理一次验证集
if (epoch + 1) % 10 == 0:
if (epoch + 1) % 5 == 0:
model.eval()
val_loss = 0
with torch.no_grad():