mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 15:01:10 +08:00
fix: GINKA train issue
This commit is contained in:
parent
fd72b1e7f4
commit
a801d6e357
@ -15,12 +15,6 @@ os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
||||
|
||||
epochs = 150
|
||||
|
||||
def update_tau(epoch):
|
||||
start_tau = 1.0
|
||||
min_tau = 0.1
|
||||
decay_rate = 0.95
|
||||
return max(min_tau, start_tau * (decay_rate ** epoch))
|
||||
|
||||
# 在生成器输出后添加梯度检查钩子
|
||||
def grad_hook(module, grad_input, grad_output):
|
||||
print(f"Generator output grad norm: {grad_output[0].norm().item()}")
|
||||
@ -73,10 +67,10 @@ def train():
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
output, output_softmax = model(feat_vec)
|
||||
output = model(feat_vec)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
|
||||
# 反向传播
|
||||
loss.backward()
|
||||
@ -113,11 +107,11 @@ def train():
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
|
||||
# 前向传播
|
||||
output, output_softmax = model(feat_vec)
|
||||
print(output_softmax[0])
|
||||
output = model(feat_vec)
|
||||
print(torch.argmax(output, dim=1)[0])
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
loss_val += loss.item()
|
||||
|
||||
avg_val_loss = loss_val / len(dataloader_val)
|
||||
@ -136,5 +130,5 @@ def train():
|
||||
}, f"result/ginka.pth")
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(8)
|
||||
torch.set_num_threads(4)
|
||||
train()
|
||||
|
||||
@ -6,7 +6,6 @@ from minamo.model.model import MinamoModel
|
||||
from .dataset import GinkaDataset
|
||||
from .model.loss import GinkaLoss
|
||||
from .model.model import GinkaModel
|
||||
from shared.graph import DynamicGraphConverter
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@ -26,10 +25,10 @@ def validate():
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
converter = DynamicGraphConverter().to(device)
|
||||
criterion = GinkaLoss(minamo, converter)
|
||||
criterion = GinkaLoss(minamo)
|
||||
|
||||
minamo.eval()
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
for batch in val_loader:
|
||||
@ -43,7 +42,7 @@ def validate():
|
||||
map_matrix = torch.argmax(output, dim=1)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(output, map_matrix, target, target_vision_feat, target_topo_feat)
|
||||
loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
total_loss += loss.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
|
||||
@ -129,7 +129,7 @@ def train():
|
||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
|
||||
# 计算损失
|
||||
loss_val = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
|
||||
loss_val = criterion(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
|
||||
val_loss += loss_val.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user