fix: GINKA train issue

This commit is contained in:
unanmed 2025-03-19 21:32:48 +08:00
parent fd72b1e7f4
commit a801d6e357
3 changed files with 10 additions and 17 deletions

View File

@@ -15,12 +15,6 @@ os.makedirs("result/ginka_checkpoint", exist_ok=True)
epochs = 150
def update_tau(epoch):
start_tau = 1.0
min_tau = 0.1
decay_rate = 0.95
return max(min_tau, start_tau * (decay_rate ** epoch))
# 在生成器输出后添加梯度检查钩子
def grad_hook(module, grad_input, grad_output):
print(f"Generator output grad norm: {grad_output[0].norm().item()}")
@@ -73,10 +67,10 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# 前向传播
optimizer.zero_grad()
output, output_softmax = model(feat_vec)
output = model(feat_vec)
# 计算损失
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)
loss = criterion(output, target, target_vision_feat, target_topo_feat)
# 反向传播
loss.backward()
@@ -113,11 +107,11 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# 前向传播
output, output_softmax = model(feat_vec)
print(output_softmax[0])
output = model(feat_vec)
print(torch.argmax(output, dim=1)[0])
# 计算损失
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)
loss = criterion(output, target, target_vision_feat, target_topo_feat)
loss_val += loss.item()
avg_val_loss = loss_val / len(dataloader_val)
@@ -136,5 +130,5 @@ def train():
}, f"result/ginka.pth")
if __name__ == "__main__":
torch.set_num_threads(8)
torch.set_num_threads(4)
train()

View File

@@ -6,7 +6,6 @@ from minamo.model.model import MinamoModel
from .dataset import GinkaDataset
from .model.loss import GinkaLoss
from .model.model import GinkaModel
from shared.graph import DynamicGraphConverter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -26,10 +25,10 @@ def validate():
shuffle=True
)
converter = DynamicGraphConverter().to(device)
criterion = GinkaLoss(minamo, converter)
criterion = GinkaLoss(minamo)
minamo.eval()
model.eval()
val_loss = 0
with torch.no_grad():
for batch in val_loader:
@@ -43,7 +42,7 @@ def validate():
map_matrix = torch.argmax(output, dim=1)
# 计算损失
loss = criterion(output, map_matrix, target, target_vision_feat, target_topo_feat)
loss = criterion(output, target, target_vision_feat, target_topo_feat)
total_loss += loss.item()
avg_val_loss = val_loss / len(val_loader)

View File

@@ -129,7 +129,7 @@ def train():
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
# 计算损失
loss_val = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
loss_val = criterion(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
val_loss += loss_val.item()
avg_val_loss = val_loss / len(val_loader)