mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-24 05:01:41 +08:00
fix: 损失值权重
This commit is contained in:
parent
8872de3f13
commit
def3cc5c3f
@ -1,7 +1,7 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class MinamoLoss(nn.Module):
|
class MinamoLoss(nn.Module):
|
||||||
def __init__(self, vision_weight=1, topo_weight=0):
|
def __init__(self, vision_weight=0.4, topo_weight=0.6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vision_weight = vision_weight
|
self.vision_weight = vision_weight
|
||||||
self.topo_weight = topo_weight
|
self.topo_weight = topo_weight
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch_geometric.nn import GATConv, AttentionalAggregation, global_max_pool
|
from torch_geometric.nn import GATConv, global_max_pool
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
|
|
||||||
class MinamoTopoModel(nn.Module):
|
class MinamoTopoModel(nn.Module):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user