perf: 优化网络参数

This commit is contained in:
unanmed 2025-03-16 20:43:37 +08:00
parent 98f7a9cdcf
commit 0910bddba2
3 changed files with 29 additions and 11 deletions

View File

@@ -1,7 +1,7 @@
import torch.nn as nn import torch.nn as nn
class MinamoLoss(nn.Module): class MinamoLoss(nn.Module):
def __init__(self, vision_weight=0, topo_weight=1): def __init__(self, vision_weight=0.4, topo_weight=0.6):
super().__init__() super().__init__()
self.vision_weight = vision_weight self.vision_weight = vision_weight
self.topo_weight = topo_weight self.topo_weight = topo_weight

View File

@@ -45,7 +45,7 @@ class DirectionalAttention(nn.Module):
return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1) return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)
class MinamoModel(nn.Module): class MinamoModel(nn.Module):
def __init__(self, tile_types=32, embedding_dim=64, conv_channels=256): def __init__(self, tile_types=32, embedding_dim=16, conv_channels=32):
super().__init__() super().__init__()
# 嵌入层处理不同图块类型 # 嵌入层处理不同图块类型
self.embedding = nn.Embedding(tile_types, embedding_dim) self.embedding = nn.Embedding(tile_types, embedding_dim)
@@ -57,25 +57,43 @@ class MinamoModel(nn.Module):
nn.ReLU(), nn.ReLU(),
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1), nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
DualAttention(conv_channels*2), DualAttention(conv_channels*2),
nn.BatchNorm2d(conv_channels*2),
nn.ReLU(),
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
DualAttention(conv_channels*4),
nn.BatchNorm2d(conv_channels*4),
nn.ReLU(),
nn.Conv2d(conv_channels*4, conv_channels*8, 3, padding=1),
DualAttention(conv_channels*8),
nn.BatchNorm2d(conv_channels*8),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1) nn.AdaptiveAvgPool2d(1)
) )
# 拓扑特征分支 # 拓扑特征分支
self.topo_conv = nn.Sequential( self.topo_conv = nn.Sequential(
nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构 nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构
nn.MaxPool2d(2), nn.BatchNorm2d(conv_channels),
nn.ReLU(),
nn.Conv2d(conv_channels, conv_channels*2, 5, padding=2), # 更大卷积核捕捉结构
nn.BatchNorm2d(conv_channels*2),
nn.ReLU(),
nn.Conv2d(conv_channels*2, conv_channels*4, 5, padding=2), # 更大卷积核捕捉结构
nn.BatchNorm2d(conv_channels*4),
nn.ReLU(),
# nn.MaxPool2d(2),
# GraphConvLayer(128, 256), # 图卷积层 # GraphConvLayer(128, 256), # 图卷积层
nn.AdaptiveMaxPool2d(1) nn.AdaptiveMaxPool2d(1)
) )
# 多任务预测头 # 多任务预测头
self.vision_head = nn.Sequential( self.vision_head = nn.Sequential(
nn.Linear(conv_channels*2, 1), nn.Linear(conv_channels*8, 1),
nn.Sigmoid() nn.Sigmoid()
) )
self.topo_head = nn.Sequential( self.topo_head = nn.Sequential(
nn.Linear(conv_channels, 1), nn.Linear(conv_channels*4, 1),
nn.Sigmoid() nn.Sigmoid()
) )

View File

@@ -33,21 +33,21 @@ def collate_fn(batch):
) )
def train(): def train():
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.") print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
model = MinamoModel(32) model = MinamoModel(32)
model.to(device) model.to(device)
# 准备数据集 # 准备数据集
dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-dataset.json") dataset = MinamoDataset("minamo-dataset.json")
val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json") val_dataset = MinamoDataset("minamo-eval.json")
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
batch_size=32, batch_size=64,
shuffle=True shuffle=True
) )
val_loader = DataLoader( val_loader = DataLoader(
val_dataset, val_dataset,
batch_size=32, batch_size=64,
shuffle=True shuffle=True
) )
@@ -98,7 +98,7 @@ def train():
scheduler.step() scheduler.step()
# 每十轮推理一次验证集 # 每五轮推理一次验证集
if (epoch + 1) % 10 == 0: if (epoch + 1) % 5 == 0:
model.eval() model.eval()
val_loss = 0 val_loss = 0
with torch.no_grad(): with torch.no_grad():