From eb0626ef88052a4c523086442a579df8bec31ec9 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sat, 22 Mar 2025 11:56:14 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=20Minamo=20Model=20?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E6=B7=B1=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/model/loss.py | 8 ++++---- ginka/train.py | 10 +++++----- ginka/validate.py | 4 ++-- minamo/model/loss.py | 1 + minamo/model/topo.py | 10 ++++++++++ minamo/model/vision.py | 5 +++++ minamo/train.py | 10 +++++++++- minamo/validate.py | 2 +- 8 files changed, 37 insertions(+), 13 deletions(-) diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 9d90951..fa96174 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -335,13 +335,13 @@ class GinkaLoss(nn.Module): losses = [ minamo_loss * self.weight[0], - border_loss * self.weight[1], + border_loss * self.weight[1] * 0.1, entrance_loss * self.weight[2], count_loss * self.weight[3], illegal_loss * self.weight[4] ] # 梯度归一化 - # scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses] - total_loss = sum(losses) - return total_loss + scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses] + total_loss = sum(scaled_losses) + return total_loss, sum(losses) diff --git a/ginka/train.py b/ginka/train.py index b4ec39f..8e71222 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -78,12 +78,12 @@ def train(): output = model(feat_vec) # 计算损失 - loss = criterion(output, target, target_vision_feat, target_topo_feat) + scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat) # 反向传播 - loss.backward() + scaled_losses.backward() optimizer.step() - total_loss += loss.item() + total_loss += losses.item() avg_loss = total_loss / len(dataloader) tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") @@ -119,8 +119,8 @@ def train(): print(torch.argmax(output, dim=1)[0]) # 计算损失 - loss = criterion(output, target, target_vision_feat, target_topo_feat) - loss_val += loss.item() + scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat) + loss_val += losses.item() avg_val_loss = loss_val / len(dataloader_val) tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") diff --git a/ginka/validate.py b/ginka/validate.py index 60318d8..c35e461 100644 --- a/ginka/validate.py +++ b/ginka/validate.py @@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32): def validate(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.") model = GinkaModel() - state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"] + state = torch.load("result/ginka_checkpoint/30.pth", map_location=device)["model_state"] model.load_state_dict(state) model.to(device) @@ -113,7 +113,7 @@ def validate(): idx += 1 # 计算损失 - loss = criterion(output, target, target_vision_feat, target_topo_feat) + _, loss = criterion(output, target, target_vision_feat, target_topo_feat) val_loss += loss.item() avg_val_loss = val_loss / len(val_loader) diff --git a/minamo/model/loss.py b/minamo/model/loss.py index fe99bca..c60a948 100644 --- a/minamo/model/loss.py +++ b/minamo/model/loss.py @@ -12,4 +12,5 @@ class MinamoLoss(nn.Module): # print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item()) vis_loss = self.mse(vis_pred, vis_true) topo_loss = self.mse(topo_pred, topo_true) + # print(vis_loss.item(), topo_loss.item()) return self.vision_weight * vis_loss + self.topo_weight * topo_loss diff --git a/minamo/model/topo.py b/minamo/model/topo.py index 9a4a00e..f20cdb2 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -14,11 +14,15 @@ class MinamoTopoModel(nn.Module): # 图卷积层 self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2) self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4) + self.conv_ins2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4, dropout=0.3) + self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2) self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False) # 正则化 self.norm1 = nn.LayerNorm(hidden_dim*16) self.norm2 = nn.LayerNorm(hidden_dim*16) + self.norm_ins2 = nn.LayerNorm(hidden_dim*16) + self.norm_ins1 = nn.LayerNorm(hidden_dim*16) self.norm3 = nn.LayerNorm(out_dim) # 池化层 @@ -40,6 +44,12 @@ class MinamoTopoModel(nn.Module): x = self.conv2(x, graph.edge_index) x = F.elu(self.norm2(x)) + x = self.conv_ins2(x, graph.edge_index) + x = F.elu(self.norm_ins2(x)) + + x = self.conv_ins1(x, graph.edge_index) + x = F.elu(self.norm_ins1(x)) + x = self.conv3(x, graph.edge_index) x = F.elu(self.norm3(x)) diff --git a/minamo/model/vision.py b/minamo/model/vision.py index fc39a16..35126d9 100644 --- a/minamo/model/vision.py +++ b/minamo/model/vision.py @@ -30,6 +30,11 @@ class MinamoVisionModel(nn.Module): CBAM(conv_ch*8), nn.GELU(), + nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1), + nn.BatchNorm2d(conv_ch*8), + CBAM(conv_ch*8), + nn.GELU(), + nn.AdaptiveMaxPool2d(1) ) diff --git a/minamo/train.py b/minamo/train.py index cb90c9c..e84bb14 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -61,16 +61,24 @@ def train(): if args.resume: data = torch.load(args.from_state, map_location=device) - model.load_state_dict(data["model_state"]) + model.load_state_dict(data["model_state"], strict=False) if args.load_optim: optimizer.load_state_dict(data["optimizer_state"]) print("Train from loaded state.") + + # for name, param in model.named_parameters(): + # if 'ins' not in name: # 仅训练扩展部分 + # param.requires_grad = False # 开始训练 for epoch in tqdm(range(args.epochs)): model.train() total_loss = 0 + # if epoch == 30: + # for name, param in model.named_parameters(): + # param.requires_grad = True + for batch in dataloader: # 数据迁移到设备 map1, map2, vision_simi, topo_simi, graph1, graph2 = batch diff --git a/minamo/validate.py b/minamo/validate.py index 1af2731..7b70f59 100644 --- a/minamo/validate.py +++ b/minamo/validate.py @@ -15,7 +15,7 @@ def validate(): model.to(device) # 准备数据集 - val_dataset = MinamoDataset("minamo-eval.json") + val_dataset = MinamoDataset("datasets/minamo-eval.json") val_loader = DataLoader( val_dataset, batch_size=32,