feat: 适配新的数据集

This commit is contained in:
unanmed 2025-12-11 22:55:55 +08:00
parent 224005b44b
commit 310b3fae80
3 changed files with 32 additions and 15 deletions

View File

@@ -25,10 +25,10 @@ class ConditionEncoder(nn.Module):
)
def forward(self, tag, val, stage):
tag = self.tag_embed(tag)
# tag = self.tag_embed(tag)
val = self.val_embed(val)
stage = self.stage_embed(stage)
feat = torch.stack([tag, val, stage], dim=1)
feat = torch.stack([val, stage], dim=1)
feat = self.encoder(feat)
feat = torch.mean(feat, dim=1)
feat = self.fusion(feat)

View File

@@ -128,8 +128,8 @@ def train():
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=100, T_mult=1)
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=100, T_mult=1)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2)
criterion = WGANGinkaLoss()
@@ -196,6 +196,8 @@ def train():
dis_total = torch.Tensor([0]).to(device)
loss_ce_total = torch.Tensor([0]).to(device)
iters = 0
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
rand = batch["rand"].to(device)
real0 = batch["real0"].to(device)
@@ -280,12 +282,27 @@ def train():
optimizer_ginka.step()
loss_total_ginka += loss_g.detach()
iters += 1
if iters % 50 == 0:
avg_loss_ginka = loss_total_ginka.item() / iters / g_steps
avg_loss_minamo = loss_total_minamo.item() / iters / c_steps
avg_loss_ce = loss_ce_total.item() / iters / g_steps
avg_dis = dis_total.item() / iters / c_steps
tqdm.write(
f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
)
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
avg_dis = dis_total.item() / len(dataloader) / c_steps
tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
@@ -368,14 +385,14 @@ def train():
# 训练流程控制
if train_stage >= 2:
# train_stage = 4
if (epoch + 1) % 100 == 5:
train_stage = 3
elif (epoch + 1) % 100 == 20:
train_stage = 4
elif (epoch + 1) % 100 == 0:
train_stage = 2
# if train_stage >= 2:
# # train_stage = 4
# if (epoch + 1) % 100 == 5:
# train_stage = 3
# elif (epoch + 1) % 100 == 20:
# train_stage = 4
# elif (epoch + 1) % 100 == 0:
# train_stage = 2
if train_stage == 1:
if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
@@ -385,7 +402,7 @@ def train():
stage_epoch = 0
if mask_ratio >= 0.8:
train_stage = 2
train_stage = 4
stage_epoch += 1
total_epoch += 1

View File

@@ -1,4 +1,4 @@
# 从头训练
python3 -u -m ginka.train_wgan --epochs 300 >> output.log
python3 -u -m ginka.train_wgan --epochs 20 --curr_epoch 1 --checkpoint 1 >> output.log
# 接续训练
python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log