diff --git a/minamo/model/loss.py b/minamo/model/loss.py
index 4ada733..6c292b8 100644
--- a/minamo/model/loss.py
+++ b/minamo/model/loss.py
@@ -1,7 +1,7 @@
 import torch.nn as nn
 
 class MinamoLoss(nn.Module):
-    def __init__(self, vision_weight=0, topo_weight=1):
+    def __init__(self, vision_weight=0.4, topo_weight=0.6):
         super().__init__()
         self.vision_weight = vision_weight
         self.topo_weight = topo_weight
diff --git a/minamo/model/model.py b/minamo/model/model.py
index 1ea8ad0..84797bb 100644
--- a/minamo/model/model.py
+++ b/minamo/model/model.py
@@ -45,7 +45,7 @@ class DirectionalAttention(nn.Module):
         return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)
 
 class MinamoModel(nn.Module):
-    def __init__(self, tile_types=32, embedding_dim=64, conv_channels=256):
+    def __init__(self, tile_types=32, embedding_dim=16, conv_channels=32):
         super().__init__()
         # 嵌入层处理不同图块类型
         self.embedding = nn.Embedding(tile_types, embedding_dim)
@@ -57,25 +57,43 @@ class MinamoModel(nn.Module):
             nn.ReLU(),
             nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
             DualAttention(conv_channels*2),
+            nn.BatchNorm2d(conv_channels*2),
+            nn.ReLU(),
+            nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
+            DualAttention(conv_channels*4),
+            nn.BatchNorm2d(conv_channels*4),
+            nn.ReLU(),
+            nn.Conv2d(conv_channels*4, conv_channels*8, 3, padding=1),
+            DualAttention(conv_channels*8),
+            nn.BatchNorm2d(conv_channels*8),
+            nn.ReLU(),
             nn.AdaptiveAvgPool2d(1)
         )
 
         # 拓扑特征分支
         self.topo_conv = nn.Sequential(
             nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构
-            nn.MaxPool2d(2),
+            nn.BatchNorm2d(conv_channels),
+            nn.ReLU(),
+            nn.Conv2d(conv_channels, conv_channels*2, 5, padding=2), # 更大卷积核捕捉结构
+            nn.BatchNorm2d(conv_channels*2),
+            nn.ReLU(),
+            nn.Conv2d(conv_channels*2, conv_channels*4, 5, padding=2), # 更大卷积核捕捉结构
+            nn.BatchNorm2d(conv_channels*4),
+            nn.ReLU(),
+            # nn.MaxPool2d(2),
             # GraphConvLayer(128, 256), # 图卷积层
             nn.AdaptiveMaxPool2d(1)
         )
 
         # 多任务预测头
         self.vision_head = nn.Sequential(
-            nn.Linear(conv_channels*2, 1),
+            nn.Linear(conv_channels*8, 1),
            nn.Sigmoid()
         )
 
         self.topo_head = nn.Sequential(
-            nn.Linear(conv_channels, 1),
+            nn.Linear(conv_channels*4, 1),
             nn.Sigmoid()
         )
 
diff --git a/minamo/train.py b/minamo/train.py
index 7143a72..28ff107 100644
--- a/minamo/train.py
+++ b/minamo/train.py
@@ -33,21 +33,21 @@ def collate_fn(batch):
     )
 
 def train():
-    print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.")
+    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
     model = MinamoModel(32)
     model.to(device)
 
     # 准备数据集
-    dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-dataset.json")
-    val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json")
+    dataset = MinamoDataset("minamo-dataset.json")
+    val_dataset = MinamoDataset("minamo-eval.json")
     dataloader = DataLoader(
         dataset,
-        batch_size=32,
+        batch_size=64,
         shuffle=True
     )
     val_loader = DataLoader(
         val_dataset,
-        batch_size=32,
+        batch_size=64,
         shuffle=True
     )
 
@@ -98,7 +98,7 @@ def train():
         scheduler.step()
 
         # 每十轮推理一次验证集
-        if (epoch + 1) % 10 == 0:
+        if (epoch + 1) % 5 == 0:
             model.eval()
             val_loss = 0
             with torch.no_grad():