mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 适配新的数据集
This commit is contained in:
parent
224005b44b
commit
310b3fae80
@ -25,10 +25,10 @@ class ConditionEncoder(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, tag, val, stage):
|
||||
tag = self.tag_embed(tag)
|
||||
# tag = self.tag_embed(tag)
|
||||
val = self.val_embed(val)
|
||||
stage = self.stage_embed(stage)
|
||||
feat = torch.stack([tag, val, stage], dim=1)
|
||||
feat = torch.stack([val, stage], dim=1)
|
||||
feat = self.encoder(feat)
|
||||
feat = torch.mean(feat, dim=1)
|
||||
feat = self.fusion(feat)
|
||||
|
||||
@ -128,8 +128,8 @@ def train():
|
||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=100, T_mult=1)
|
||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=100, T_mult=1)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
|
||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2)
|
||||
|
||||
criterion = WGANGinkaLoss()
|
||||
|
||||
@ -196,6 +196,8 @@ def train():
|
||||
dis_total = torch.Tensor([0]).to(device)
|
||||
loss_ce_total = torch.Tensor([0]).to(device)
|
||||
|
||||
iters = 0
|
||||
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
rand = batch["rand"].to(device)
|
||||
real0 = batch["real0"].to(device)
|
||||
@ -280,12 +282,27 @@ def train():
|
||||
optimizer_ginka.step()
|
||||
loss_total_ginka += loss_g.detach()
|
||||
|
||||
iters += 1
|
||||
|
||||
if iters % 50 == 0:
|
||||
avg_loss_ginka = loss_total_ginka.item() / iters / g_steps
|
||||
avg_loss_minamo = loss_total_minamo.item() / iters / c_steps
|
||||
avg_loss_ce = loss_ce_total.item() / iters / g_steps
|
||||
avg_dis = dis_total.item() / iters / c_steps
|
||||
tqdm.write(
|
||||
f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||
f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
||||
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
|
||||
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
|
||||
)
|
||||
|
||||
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
|
||||
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
|
||||
avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
|
||||
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
||||
tqdm.write(
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||
f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||
f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
||||
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
|
||||
@ -368,14 +385,14 @@ def train():
|
||||
|
||||
# 训练流程控制
|
||||
|
||||
if train_stage >= 2:
|
||||
# train_stage = 4
|
||||
if (epoch + 1) % 100 == 5:
|
||||
train_stage = 3
|
||||
elif (epoch + 1) % 100 == 20:
|
||||
train_stage = 4
|
||||
elif (epoch + 1) % 100 == 0:
|
||||
train_stage = 2
|
||||
# if train_stage >= 2:
|
||||
# # train_stage = 4
|
||||
# if (epoch + 1) % 100 == 5:
|
||||
# train_stage = 3
|
||||
# elif (epoch + 1) % 100 == 20:
|
||||
# train_stage = 4
|
||||
# elif (epoch + 1) % 100 == 0:
|
||||
# train_stage = 2
|
||||
|
||||
if train_stage == 1:
|
||||
if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
|
||||
@ -385,7 +402,7 @@ def train():
|
||||
|
||||
stage_epoch = 0
|
||||
if mask_ratio >= 0.8:
|
||||
train_stage = 2
|
||||
train_stage = 4
|
||||
|
||||
stage_epoch += 1
|
||||
total_epoch += 1
|
||||
|
||||
2
train.sh
2
train.sh
@ -1,4 +1,4 @@
|
||||
# 从头训练
|
||||
python3 -u -m ginka.train_wgan --epochs 300 >> output.log
|
||||
python3 -u -m ginka.train_wgan --epochs 20 --curr_epoch 1 --checkpoint 1 >> output.log
|
||||
# 接续训练
|
||||
python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
|
||||
|
||||
Loading…
Reference in New Issue
Block a user