fix: GINKA train issue

This commit is contained in:
unanmed 2025-03-19 21:32:48 +08:00
parent fd72b1e7f4
commit a801d6e357
3 changed files with 10 additions and 17 deletions

View File

@ -15,12 +15,6 @@ os.makedirs("result/ginka_checkpoint", exist_ok=True)
epochs = 150 epochs = 150
def update_tau(epoch, *, start_tau=1.0, min_tau=0.1, decay_rate=0.95):
    """Return the exponentially decayed tau value for a given epoch.

    The value starts at ``start_tau`` for epoch 0, is multiplied by
    ``decay_rate`` once per epoch, and is never allowed to drop below
    ``min_tau``.

    Args:
        epoch: Epoch index used as the exponent of the decay.
        start_tau: Value returned at epoch 0 (before flooring).
        min_tau: Lower bound applied to the decayed value.
        decay_rate: Per-epoch multiplicative factor.

    Returns:
        ``max(min_tau, start_tau * decay_rate ** epoch)``.
    """
    decayed = start_tau * (decay_rate ** epoch)
    # Clamp so the schedule bottoms out instead of approaching zero.
    return max(min_tau, decayed)
# 在生成器输出后添加梯度检查钩子 # 在生成器输出后添加梯度检查钩子
def grad_hook(module, grad_input, grad_output): def grad_hook(module, grad_input, grad_output):
print(f"Generator output grad norm: {grad_output[0].norm().item()}") print(f"Generator output grad norm: {grad_output[0].norm().item()}")
@ -73,10 +67,10 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# 前向传播 # 前向传播
optimizer.zero_grad() optimizer.zero_grad()
output, output_softmax = model(feat_vec) output = model(feat_vec)
# 计算损失 # 计算损失
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat) loss = criterion(output, target, target_vision_feat, target_topo_feat)
# 反向传播 # 反向传播
loss.backward() loss.backward()
@ -113,11 +107,11 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# 前向传播 # 前向传播
output, output_softmax = model(feat_vec) output = model(feat_vec)
print(output_softmax[0]) print(torch.argmax(output, dim=1)[0])
# 计算损失 # 计算损失
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat) loss = criterion(output, target, target_vision_feat, target_topo_feat)
loss_val += loss.item() loss_val += loss.item()
avg_val_loss = loss_val / len(dataloader_val) avg_val_loss = loss_val / len(dataloader_val)
@ -136,5 +130,5 @@ def train():
}, f"result/ginka.pth") }, f"result/ginka.pth")
if __name__ == "__main__": if __name__ == "__main__":
torch.set_num_threads(8) torch.set_num_threads(4)
train() train()

View File

@ -6,7 +6,6 @@ from minamo.model.model import MinamoModel
from .dataset import GinkaDataset from .dataset import GinkaDataset
from .model.loss import GinkaLoss from .model.loss import GinkaLoss
from .model.model import GinkaModel from .model.model import GinkaModel
from shared.graph import DynamicGraphConverter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -26,10 +25,10 @@ def validate():
shuffle=True shuffle=True
) )
converter = DynamicGraphConverter().to(device) criterion = GinkaLoss(minamo)
criterion = GinkaLoss(minamo, converter)
minamo.eval() minamo.eval()
model.eval()
val_loss = 0 val_loss = 0
with torch.no_grad(): with torch.no_grad():
for batch in val_loader: for batch in val_loader:
@ -43,7 +42,7 @@ def validate():
map_matrix = torch.argmax(output, dim=1) map_matrix = torch.argmax(output, dim=1)
# 计算损失 # 计算损失
loss = criterion(output, map_matrix, target, target_vision_feat, target_topo_feat) loss = criterion(output, target, target_vision_feat, target_topo_feat)
total_loss += loss.item() total_loss += loss.item()
avg_val_loss = val_loss / len(val_loader) avg_val_loss = val_loss / len(val_loader)

View File

@ -129,7 +129,7 @@ def train():
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
# 计算损失 # 计算损失
loss_val = criterion(vision_pred, topo_pred, vision_simi, topo_simi) loss_val = criterion(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
val_loss += loss_val.item() val_loss += loss_val.item()
avg_val_loss = val_loss / len(val_loader) avg_val_loss = val_loss / len(val_loader)