mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 02:11:13 +08:00
perf: 优化网络参数
This commit is contained in:
parent
98f7a9cdcf
commit
0910bddba2
@ -1,7 +1,7 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class MinamoLoss(nn.Module):
|
class MinamoLoss(nn.Module):
|
||||||
def __init__(self, vision_weight=0, topo_weight=1):
|
def __init__(self, vision_weight=0.4, topo_weight=0.6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vision_weight = vision_weight
|
self.vision_weight = vision_weight
|
||||||
self.topo_weight = topo_weight
|
self.topo_weight = topo_weight
|
||||||
|
|||||||
@ -45,7 +45,7 @@ class DirectionalAttention(nn.Module):
|
|||||||
return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)
|
return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)
|
||||||
|
|
||||||
class MinamoModel(nn.Module):
|
class MinamoModel(nn.Module):
|
||||||
def __init__(self, tile_types=32, embedding_dim=64, conv_channels=256):
|
def __init__(self, tile_types=32, embedding_dim=16, conv_channels=32):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 嵌入层处理不同图块类型
|
# 嵌入层处理不同图块类型
|
||||||
self.embedding = nn.Embedding(tile_types, embedding_dim)
|
self.embedding = nn.Embedding(tile_types, embedding_dim)
|
||||||
@ -57,25 +57,43 @@ class MinamoModel(nn.Module):
|
|||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
||||||
DualAttention(conv_channels*2),
|
DualAttention(conv_channels*2),
|
||||||
|
nn.BatchNorm2d(conv_channels*2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
||||||
|
DualAttention(conv_channels*4),
|
||||||
|
nn.BatchNorm2d(conv_channels*4),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(conv_channels*4, conv_channels*8, 3, padding=1),
|
||||||
|
DualAttention(conv_channels*8),
|
||||||
|
nn.BatchNorm2d(conv_channels*8),
|
||||||
|
nn.ReLU(),
|
||||||
nn.AdaptiveAvgPool2d(1)
|
nn.AdaptiveAvgPool2d(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 拓扑特征分支
|
# 拓扑特征分支
|
||||||
self.topo_conv = nn.Sequential(
|
self.topo_conv = nn.Sequential(
|
||||||
nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构
|
nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构
|
||||||
nn.MaxPool2d(2),
|
nn.BatchNorm2d(conv_channels),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(conv_channels, conv_channels*2, 5, padding=2), # 更大卷积核捕捉结构
|
||||||
|
nn.BatchNorm2d(conv_channels*2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(conv_channels*2, conv_channels*4, 5, padding=2), # 更大卷积核捕捉结构
|
||||||
|
nn.BatchNorm2d(conv_channels*4),
|
||||||
|
nn.ReLU(),
|
||||||
|
# nn.MaxPool2d(2),
|
||||||
# GraphConvLayer(128, 256), # 图卷积层
|
# GraphConvLayer(128, 256), # 图卷积层
|
||||||
nn.AdaptiveMaxPool2d(1)
|
nn.AdaptiveMaxPool2d(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 多任务预测头
|
# 多任务预测头
|
||||||
self.vision_head = nn.Sequential(
|
self.vision_head = nn.Sequential(
|
||||||
nn.Linear(conv_channels*2, 1),
|
nn.Linear(conv_channels*8, 1),
|
||||||
nn.Sigmoid()
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.topo_head = nn.Sequential(
|
self.topo_head = nn.Sequential(
|
||||||
nn.Linear(conv_channels, 1),
|
nn.Linear(conv_channels*4, 1),
|
||||||
nn.Sigmoid()
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -33,21 +33,21 @@ def collate_fn(batch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||||
model = MinamoModel(32)
|
model = MinamoModel(32)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
# 准备数据集
|
# 准备数据集
|
||||||
dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-dataset.json")
|
dataset = MinamoDataset("minamo-dataset.json")
|
||||||
val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json")
|
val_dataset = MinamoDataset("minamo-eval.json")
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=32,
|
batch_size=64,
|
||||||
shuffle=True
|
shuffle=True
|
||||||
)
|
)
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=32,
|
batch_size=64,
|
||||||
shuffle=True
|
shuffle=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ def train():
|
|||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
# 每十轮推理一次验证集
|
# 每十轮推理一次验证集
|
||||||
if (epoch + 1) % 10 == 0:
|
if (epoch + 1) % 5 == 0:
|
||||||
model.eval()
|
model.eval()
|
||||||
val_loss = 0
|
val_loss = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user