From 0910bddba2c6a8af1d2e46480494c34e7e9561e6 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sun, 16 Mar 2025 20:43:37 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E7=BD=91=E7=BB=9C?= =?UTF-8?q?=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- minamo/model/loss.py | 2 +- minamo/model/model.py | 26 ++++++++++++++++++++++---- minamo/train.py | 12 ++++++------ 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/minamo/model/loss.py b/minamo/model/loss.py index 4ada733..6c292b8 100644 --- a/minamo/model/loss.py +++ b/minamo/model/loss.py @@ -1,7 +1,7 @@ import torch.nn as nn class MinamoLoss(nn.Module): - def __init__(self, vision_weight=0, topo_weight=1): + def __init__(self, vision_weight=0.4, topo_weight=0.6): super().__init__() self.vision_weight = vision_weight self.topo_weight = topo_weight diff --git a/minamo/model/model.py b/minamo/model/model.py index 1ea8ad0..84797bb 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -45,7 +45,7 @@ class DirectionalAttention(nn.Module): return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1) class MinamoModel(nn.Module): - def __init__(self, tile_types=32, embedding_dim=64, conv_channels=256): + def __init__(self, tile_types=32, embedding_dim=16, conv_channels=32): super().__init__() # 嵌入层处理不同图块类型 self.embedding = nn.Embedding(tile_types, embedding_dim) @@ -57,25 +57,43 @@ class MinamoModel(nn.Module): nn.ReLU(), nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1), DualAttention(conv_channels*2), + nn.BatchNorm2d(conv_channels*2), + nn.ReLU(), + nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1), + DualAttention(conv_channels*4), + nn.BatchNorm2d(conv_channels*4), + nn.ReLU(), + nn.Conv2d(conv_channels*4, conv_channels*8, 3, padding=1), + DualAttention(conv_channels*8), + nn.BatchNorm2d(conv_channels*8), + nn.ReLU(), nn.AdaptiveAvgPool2d(1) ) # 拓扑特征分支 self.topo_conv = nn.Sequential( nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构 - nn.MaxPool2d(2), + nn.BatchNorm2d(conv_channels), + nn.ReLU(), + nn.Conv2d(conv_channels, conv_channels*2, 5, padding=2), # 更大卷积核捕捉结构 + nn.BatchNorm2d(conv_channels*2), + nn.ReLU(), + nn.Conv2d(conv_channels*2, conv_channels*4, 5, padding=2), # 更大卷积核捕捉结构 + nn.BatchNorm2d(conv_channels*4), + nn.ReLU(), + # nn.MaxPool2d(2), # GraphConvLayer(128, 256), # 图卷积层 nn.AdaptiveMaxPool2d(1) ) # 多任务预测头 self.vision_head = nn.Sequential( - nn.Linear(conv_channels*2, 1), + nn.Linear(conv_channels*8, 1), nn.Sigmoid() ) self.topo_head = nn.Sequential( - nn.Linear(conv_channels, 1), + nn.Linear(conv_channels*4, 1), nn.Sigmoid() ) diff --git a/minamo/train.py b/minamo/train.py index 7143a72..28ff107 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -33,21 +33,21 @@ def collate_fn(batch): ) def train(): - print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.") + print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") model = MinamoModel(32) model.to(device) # 准备数据集 - dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-dataset.json") - val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json") + dataset = MinamoDataset("minamo-dataset.json") + val_dataset = MinamoDataset("minamo-eval.json") dataloader = DataLoader( dataset, - batch_size=32, + batch_size=64, shuffle=True ) val_loader = DataLoader( val_dataset, - batch_size=32, + batch_size=64, shuffle=True ) @@ -98,7 +98,7 @@ def train(): scheduler.step() # 每十轮推理一次验证集 - if (epoch + 1) % 10 == 0: + if (epoch + 1) % 5 == 0: model.eval() val_loss = 0 with torch.no_grad():