mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 03:51:11 +08:00
Merge branch 'master' of github.com:unanmed/ginka-generator
This commit is contained in:
commit
787ccc4af8
@ -25,10 +25,10 @@ class ConditionEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, tag, val, stage):
|
def forward(self, tag, val, stage):
|
||||||
tag = self.tag_embed(tag)
|
# tag = self.tag_embed(tag)
|
||||||
val = self.val_embed(val)
|
val = self.val_embed(val)
|
||||||
stage = self.stage_embed(stage)
|
stage = self.stage_embed(stage)
|
||||||
feat = torch.stack([tag, val, stage], dim=1)
|
feat = torch.stack([val, stage], dim=1)
|
||||||
feat = self.encoder(feat)
|
feat = self.encoder(feat)
|
||||||
feat = torch.mean(feat, dim=1)
|
feat = torch.mean(feat, dim=1)
|
||||||
feat = self.fusion(feat)
|
feat = self.fusion(feat)
|
||||||
|
|||||||
@ -128,8 +128,8 @@ def train():
|
|||||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||||
|
|
||||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=100, T_mult=1)
|
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
|
||||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=100, T_mult=1)
|
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2)
|
||||||
|
|
||||||
criterion = WGANGinkaLoss()
|
criterion = WGANGinkaLoss()
|
||||||
|
|
||||||
@ -196,6 +196,8 @@ def train():
|
|||||||
dis_total = torch.Tensor([0]).to(device)
|
dis_total = torch.Tensor([0]).to(device)
|
||||||
loss_ce_total = torch.Tensor([0]).to(device)
|
loss_ce_total = torch.Tensor([0]).to(device)
|
||||||
|
|
||||||
|
iters = 0
|
||||||
|
|
||||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||||
rand = batch["rand"].to(device)
|
rand = batch["rand"].to(device)
|
||||||
real0 = batch["real0"].to(device)
|
real0 = batch["real0"].to(device)
|
||||||
@ -280,12 +282,27 @@ def train():
|
|||||||
optimizer_ginka.step()
|
optimizer_ginka.step()
|
||||||
loss_total_ginka += loss_g.detach()
|
loss_total_ginka += loss_g.detach()
|
||||||
|
|
||||||
|
iters += 1
|
||||||
|
|
||||||
|
if iters % 50 == 0:
|
||||||
|
avg_loss_ginka = loss_total_ginka.item() / iters / g_steps
|
||||||
|
avg_loss_minamo = loss_total_minamo.item() / iters / c_steps
|
||||||
|
avg_loss_ce = loss_ce_total.item() / iters / g_steps
|
||||||
|
avg_dis = dis_total.item() / iters / c_steps
|
||||||
|
tqdm.write(
|
||||||
|
f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||||
|
f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
||||||
|
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
||||||
|
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
|
||||||
|
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
|
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
|
||||||
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
|
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
|
||||||
avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
|
avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
|
||||||
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||||
f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
||||||
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
||||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
|
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
|
||||||
@ -368,14 +385,14 @@ def train():
|
|||||||
|
|
||||||
# 训练流程控制
|
# 训练流程控制
|
||||||
|
|
||||||
if train_stage >= 2:
|
# if train_stage >= 2:
|
||||||
# train_stage = 4
|
# # train_stage = 4
|
||||||
if (epoch + 1) % 100 == 5:
|
# if (epoch + 1) % 100 == 5:
|
||||||
train_stage = 3
|
# train_stage = 3
|
||||||
elif (epoch + 1) % 100 == 20:
|
# elif (epoch + 1) % 100 == 20:
|
||||||
train_stage = 4
|
# train_stage = 4
|
||||||
elif (epoch + 1) % 100 == 0:
|
# elif (epoch + 1) % 100 == 0:
|
||||||
train_stage = 2
|
# train_stage = 2
|
||||||
|
|
||||||
if train_stage == 1:
|
if train_stage == 1:
|
||||||
if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
|
if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
|
||||||
@ -385,7 +402,7 @@ def train():
|
|||||||
|
|
||||||
stage_epoch = 0
|
stage_epoch = 0
|
||||||
if mask_ratio >= 0.8:
|
if mask_ratio >= 0.8:
|
||||||
train_stage = 2
|
train_stage = 4
|
||||||
|
|
||||||
stage_epoch += 1
|
stage_epoch += 1
|
||||||
total_epoch += 1
|
total_epoch += 1
|
||||||
|
|||||||
2
train.sh
2
train.sh
@ -1,4 +1,4 @@
|
|||||||
# 从头训练
|
# 从头训练
|
||||||
python3 -u -m ginka.train_wgan --epochs 300 >> output.log
|
python3 -u -m ginka.train_wgan --epochs 20 --curr_epoch 1 --checkpoint 1 >> output.log
|
||||||
# 接续训练
|
# 接续训练
|
||||||
python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
|
python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user