From def3cc5c3f94007991a4ed86057dea8e09e00217 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 25 Mar 2025 21:31:56 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=8D=9F=E5=A4=B1=E5=80=BC=E6=9D=83?= =?UTF-8?q?=E9=87=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- minamo/model/loss.py | 2 +- minamo/model/topo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/minamo/model/loss.py b/minamo/model/loss.py index 58795fc..faefebf 100644 --- a/minamo/model/loss.py +++ b/minamo/model/loss.py @@ -1,7 +1,7 @@ import torch.nn as nn class MinamoLoss(nn.Module): - def __init__(self, vision_weight=1, topo_weight=0): + def __init__(self, vision_weight=0.4, topo_weight=0.6): super().__init__() self.vision_weight = vision_weight self.topo_weight = topo_weight diff --git a/minamo/model/topo.py b/minamo/model/topo.py index 07a0941..9e34ab1 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch_geometric.nn import GATConv, AttentionalAggregation, global_max_pool +from torch_geometric.nn import GATConv, global_max_pool from torch_geometric.data import Data class MinamoTopoModel(nn.Module):