fix: 损失值权重

This commit is contained in:
unanmed 2025-03-25 21:31:56 +08:00
parent 8872de3f13
commit def3cc5c3f
2 changed files with 2 additions and 2 deletions

View File

@ -1,7 +1,7 @@
import torch.nn as nn
class MinamoLoss(nn.Module):
def __init__(self, vision_weight=1, topo_weight=0):
def __init__(self, vision_weight=0.4, topo_weight=0.6):
super().__init__()
self.vision_weight = vision_weight
self.topo_weight = topo_weight

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, AttentionalAggregation, global_max_pool
from torch_geometric.nn import GATConv, global_max_pool
from torch_geometric.data import Data
class MinamoTopoModel(nn.Module):