mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 23:21:20 +08:00
fix: 损失值权重
This commit is contained in:
parent
8872de3f13
commit
def3cc5c3f
@ -1,7 +1,7 @@
|
||||
import torch.nn as nn
|
||||
|
||||
class MinamoLoss(nn.Module):
|
||||
def __init__(self, vision_weight=1, topo_weight=0):
|
||||
def __init__(self, vision_weight=0.4, topo_weight=0.6):
|
||||
super().__init__()
|
||||
self.vision_weight = vision_weight
|
||||
self.topo_weight = topo_weight
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GATConv, AttentionalAggregation, global_max_pool
|
||||
from torch_geometric.nn import GATConv, global_max_pool
|
||||
from torch_geometric.data import Data
|
||||
|
||||
class MinamoTopoModel(nn.Module):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user