diff --git a/minamo/model/loss.py b/minamo/model/loss.py index 58795fc..faefebf 100644 --- a/minamo/model/loss.py +++ b/minamo/model/loss.py @@ -1,7 +1,7 @@ import torch.nn as nn class MinamoLoss(nn.Module): - def __init__(self, vision_weight=1, topo_weight=0): + def __init__(self, vision_weight=0.4, topo_weight=0.6): super().__init__() self.vision_weight = vision_weight self.topo_weight = topo_weight diff --git a/minamo/model/topo.py b/minamo/model/topo.py index 07a0941..9e34ab1 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch_geometric.nn import GATConv, AttentionalAggregation, global_max_pool +from torch_geometric.nn import GATConv, global_max_pool from torch_geometric.data import Data class MinamoTopoModel(nn.Module):