mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
perf: 优化网络参数
This commit is contained in:
parent
98f7a9cdcf
commit
0910bddba2
@ -1,7 +1,7 @@
|
||||
import torch.nn as nn
|
||||
|
||||
class MinamoLoss(nn.Module):
|
||||
def __init__(self, vision_weight=0, topo_weight=1):
|
||||
def __init__(self, vision_weight=0.4, topo_weight=0.6):
|
||||
super().__init__()
|
||||
self.vision_weight = vision_weight
|
||||
self.topo_weight = topo_weight
|
||||
|
||||
@ -45,7 +45,7 @@ class DirectionalAttention(nn.Module):
|
||||
return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)
|
||||
|
||||
class MinamoModel(nn.Module):
|
||||
def __init__(self, tile_types=32, embedding_dim=64, conv_channels=256):
|
||||
def __init__(self, tile_types=32, embedding_dim=16, conv_channels=32):
|
||||
super().__init__()
|
||||
# 嵌入层处理不同图块类型
|
||||
self.embedding = nn.Embedding(tile_types, embedding_dim)
|
||||
@ -57,25 +57,43 @@ class MinamoModel(nn.Module):
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
||||
DualAttention(conv_channels*2),
|
||||
nn.BatchNorm2d(conv_channels*2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
||||
DualAttention(conv_channels*4),
|
||||
nn.BatchNorm2d(conv_channels*4),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels*4, conv_channels*8, 3, padding=1),
|
||||
DualAttention(conv_channels*8),
|
||||
nn.BatchNorm2d(conv_channels*8),
|
||||
nn.ReLU(),
|
||||
nn.AdaptiveAvgPool2d(1)
|
||||
)
|
||||
|
||||
# 拓扑特征分支
|
||||
self.topo_conv = nn.Sequential(
|
||||
nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构
|
||||
nn.MaxPool2d(2),
|
||||
nn.BatchNorm2d(conv_channels),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels, conv_channels*2, 5, padding=2), # 更大卷积核捕捉结构
|
||||
nn.BatchNorm2d(conv_channels*2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels*2, conv_channels*4, 5, padding=2), # 更大卷积核捕捉结构
|
||||
nn.BatchNorm2d(conv_channels*4),
|
||||
nn.ReLU(),
|
||||
# nn.MaxPool2d(2),
|
||||
# GraphConvLayer(128, 256), # 图卷积层
|
||||
nn.AdaptiveMaxPool2d(1)
|
||||
)
|
||||
|
||||
# 多任务预测头
|
||||
self.vision_head = nn.Sequential(
|
||||
nn.Linear(conv_channels*2, 1),
|
||||
nn.Linear(conv_channels*8, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.topo_head = nn.Sequential(
|
||||
nn.Linear(conv_channels, 1),
|
||||
nn.Linear(conv_channels*4, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
|
||||
@ -33,21 +33,21 @@ def collate_fn(batch):
|
||||
)
|
||||
|
||||
def train():
|
||||
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.")
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
model = MinamoModel(32)
|
||||
model.to(device)
|
||||
|
||||
# 准备数据集
|
||||
dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-dataset.json")
|
||||
val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json")
|
||||
dataset = MinamoDataset("minamo-dataset.json")
|
||||
val_dataset = MinamoDataset("minamo-eval.json")
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
batch_size=64,
|
||||
shuffle=True
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=32,
|
||||
batch_size=64,
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
@ -98,7 +98,7 @@ def train():
|
||||
scheduler.step()
|
||||
|
||||
# 每十轮推理一次验证集
|
||||
if (epoch + 1) % 10 == 0:
|
||||
if (epoch + 1) % 5 == 0:
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user