From a801d6e357861ab35676f304c06b204635c1a5f4 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 19 Mar 2025 21:32:48 +0800 Subject: [PATCH] fix: GINKA train issue --- ginka/train.py | 18 ++++++------------ ginka/validate.py | 7 +++---- minamo/train.py | 2 +- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/ginka/train.py b/ginka/train.py index aadff42..3fcce09 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -15,12 +15,6 @@ os.makedirs("result/ginka_checkpoint", exist_ok=True) epochs = 150 -def update_tau(epoch): - start_tau = 1.0 - min_tau = 0.1 - decay_rate = 0.95 - return max(min_tau, start_tau * (decay_rate ** epoch)) - # 在生成器输出后添加梯度检查钩子 def grad_hook(module, grad_input, grad_output): print(f"Generator output grad norm: {grad_output[0].norm().item()}") @@ -73,10 +67,10 @@ def train(): feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) # 前向传播 optimizer.zero_grad() - output, output_softmax = model(feat_vec) + output = model(feat_vec) # 计算损失 - loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat) + loss = criterion(output, target, target_vision_feat, target_topo_feat) # 反向传播 loss.backward() @@ -113,11 +107,11 @@ def train(): feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) # 前向传播 - output, output_softmax = model(feat_vec) - print(output_softmax[0]) + output = model(feat_vec) + print(torch.argmax(output, dim=1)[0]) # 计算损失 - loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat) + loss = criterion(output, target, target_vision_feat, target_topo_feat) loss_val += loss.item() avg_val_loss = loss_val / len(dataloader_val) @@ -136,5 +130,5 @@ def train(): }, f"result/ginka.pth") if __name__ == "__main__": - torch.set_num_threads(8) + torch.set_num_threads(4) train() diff --git a/ginka/validate.py b/ginka/validate.py index 220c649..71eaf59 100644 --- a/ginka/validate.py +++ b/ginka/validate.py @@ -6,7 +6,6 @@ from minamo.model.model import MinamoModel from .dataset import GinkaDataset from .model.loss import GinkaLoss from .model.model import GinkaModel -from shared.graph import DynamicGraphConverter device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -26,10 +25,10 @@ def validate(): shuffle=True ) - converter = DynamicGraphConverter().to(device) - criterion = GinkaLoss(minamo, converter) + criterion = GinkaLoss(minamo) minamo.eval() + model.eval() val_loss = 0 with torch.no_grad(): for batch in val_loader: @@ -43,7 +42,7 @@ def validate(): map_matrix = torch.argmax(output, dim=1) # 计算损失 - loss = criterion(output, map_matrix, target, target_vision_feat, target_topo_feat) + loss = criterion(output, target, target_vision_feat, target_topo_feat) total_loss += loss.item() avg_val_loss = val_loss / len(val_loader) diff --git a/minamo/train.py b/minamo/train.py index 44c2aa5..dbe2775 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -129,7 +129,7 @@ def train(): topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) # 计算损失 - loss_val = criterion(vision_pred, topo_pred, vision_simi, topo_simi) + loss_val = criterion(vision_pred, topo_pred, vision_simi_val, topo_simi_val) val_loss += loss_val.item() avg_val_loss = val_loss / len(val_loader)