feat: 增加 Minamo Model 模型深度

This commit is contained in:
unanmed 2025-03-22 11:56:14 +08:00
parent c9c52109ed
commit eb0626ef88
8 changed files with 37 additions and 13 deletions

View File

@ -335,13 +335,13 @@ class GinkaLoss(nn.Module):
losses = [
minamo_loss * self.weight[0],
border_loss * self.weight[1],
border_loss * self.weight[1] * 0.1,
entrance_loss * self.weight[2],
count_loss * self.weight[3],
illegal_loss * self.weight[4]
]
# 梯度归一化
# scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
total_loss = sum(losses)
return total_loss
scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
total_loss = sum(scaled_losses)
return total_loss, sum(losses)

View File

@ -78,12 +78,12 @@ def train():
output = model(feat_vec)
# 计算损失
loss = criterion(output, target, target_vision_feat, target_topo_feat)
scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat)
# 反向传播
loss.backward()
scaled_losses.backward()
optimizer.step()
total_loss += loss.item()
total_loss += losses.item()
avg_loss = total_loss / len(dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
@ -119,8 +119,8 @@ def train():
print(torch.argmax(output, dim=1)[0])
# 计算损失
loss = criterion(output, target, target_vision_feat, target_topo_feat)
loss_val += loss.item()
scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat)
loss_val += losses.item()
avg_val_loss = loss_val / len(dataloader_val)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")

View File

@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = GinkaModel()
state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
state = torch.load("result/ginka_checkpoint/30.pth", map_location=device)["model_state"]
model.load_state_dict(state)
model.to(device)
@ -113,7 +113,7 @@ def validate():
idx += 1
# 计算损失
loss = criterion(output, target, target_vision_feat, target_topo_feat)
_, loss = criterion(output, target, target_vision_feat, target_topo_feat)
val_loss += loss.item()
avg_val_loss = val_loss / len(val_loader)

View File

@ -12,4 +12,5 @@ class MinamoLoss(nn.Module):
# print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item())
vis_loss = self.mse(vis_pred, vis_true)
topo_loss = self.mse(topo_pred, topo_true)
# print(vis_loss.item(), topo_loss.item())
return self.vision_weight * vis_loss + self.topo_weight * topo_loss

View File

@ -14,11 +14,15 @@ class MinamoTopoModel(nn.Module):
# 图卷积层
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
self.conv_ins2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4, dropout=0.3)
self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2)
self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False)
# 正则化
self.norm1 = nn.LayerNorm(hidden_dim*16)
self.norm2 = nn.LayerNorm(hidden_dim*16)
self.norm_ins2 = nn.LayerNorm(hidden_dim*16)
self.norm_ins1 = nn.LayerNorm(hidden_dim*16)
self.norm3 = nn.LayerNorm(out_dim)
# 池化层
@ -40,6 +44,12 @@ class MinamoTopoModel(nn.Module):
x = self.conv2(x, graph.edge_index)
x = F.elu(self.norm2(x))
x = self.conv_ins2(x, graph.edge_index)
x = F.elu(self.norm_ins2(x))
x = self.conv_ins1(x, graph.edge_index)
x = F.elu(self.norm_ins1(x))
x = self.conv3(x, graph.edge_index)
x = F.elu(self.norm3(x))

View File

@ -30,6 +30,11 @@ class MinamoVisionModel(nn.Module):
CBAM(conv_ch*8),
nn.GELU(),
nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1),
nn.BatchNorm2d(conv_ch*8),
CBAM(conv_ch*8),
nn.GELU(),
nn.AdaptiveMaxPool2d(1)
)

View File

@ -61,16 +61,24 @@ def train():
if args.resume:
data = torch.load(args.from_state, map_location=device)
model.load_state_dict(data["model_state"])
model.load_state_dict(data["model_state"], strict=False)
if args.load_optim:
optimizer.load_state_dict(data["optimizer_state"])
print("Train from loaded state.")
# for name, param in model.named_parameters():
# if 'ins' not in name: # 仅训练扩展部分
# param.requires_grad = False
# 开始训练
for epoch in tqdm(range(args.epochs)):
model.train()
total_loss = 0
# if epoch == 30:
# for name, param in model.named_parameters():
# param.requires_grad = True
for batch in dataloader:
# 数据迁移到设备
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch

View File

@ -15,7 +15,7 @@ def validate():
model.to(device)
# 准备数据集
val_dataset = MinamoDataset("minamo-eval.json")
val_dataset = MinamoDataset("datasets/minamo-eval.json")
val_loader = DataLoader(
val_dataset,
batch_size=32,