mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 训练循环 & 怪物数据处理
This commit is contained in:
parent
1cded87530
commit
10e6ad6394
@ -239,11 +239,11 @@ function convert(
|
||||
const attr = (enemy.atk + enemy.def) * enemy.hp;
|
||||
const ad = attr - minAttr;
|
||||
if (ad < delta / 3 || delta === 0) {
|
||||
res[ny][nx] = 25;
|
||||
} else if (ad < (delta * 2) / 3) {
|
||||
res[ny][nx] = 26;
|
||||
} else {
|
||||
} else if (ad < (delta * 2) / 3) {
|
||||
res[ny][nx] = 27;
|
||||
} else {
|
||||
res[ny][nx] = 28;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -71,9 +71,9 @@ def parse_arguments():
|
||||
return args
|
||||
|
||||
def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
fake1, _ = gen(masked1, 1, False, tag, val)
|
||||
fake2, _ = gen(masked2, 2, False, tag, val)
|
||||
fake3, _ = gen(masked3, 3, False, tag, val)
|
||||
fake1, _ = gen(masked1, 1, tag, val)
|
||||
fake2, _ = gen(masked2, 2, tag, val)
|
||||
fake3, _ = gen(masked3, 3, tag, val)
|
||||
if detach:
|
||||
return fake1.detach(), fake2.detach(), fake3.detach()
|
||||
else:
|
||||
@ -81,13 +81,13 @@ def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tu
|
||||
|
||||
def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
|
||||
if progress_detach:
|
||||
fake1, x_in = gen(input.detach(), 1, random, tag, val)
|
||||
fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, False, tag, val)
|
||||
fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, False, tag, val)
|
||||
fake1, x_in = gen(input.detach(), 1, tag, val, random)
|
||||
fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val)
|
||||
fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val)
|
||||
else:
|
||||
fake1, x_in = gen(input, 1, random, tag, val)
|
||||
fake2, _ = gen(F.softmax(fake1, dim=1), 2, False, tag, val)
|
||||
fake3, _ = gen(F.softmax(fake2, dim=1), 3, False, tag, val)
|
||||
fake1, x_in = gen(input, 1, tag, val, random)
|
||||
fake2, _ = gen(F.softmax(fake1, dim=1), 2, tag, val)
|
||||
fake3, _ = gen(F.softmax(fake2, dim=1), 3, tag, val)
|
||||
if result_detach:
|
||||
return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
|
||||
else:
|
||||
@ -205,9 +205,9 @@ def train():
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
|
||||
|
||||
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
|
||||
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
|
||||
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
|
||||
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond)
|
||||
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond)
|
||||
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond)
|
||||
|
||||
dis_avg = (dis1 + dis2 + dis3) / 3.0
|
||||
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
|
||||
@ -226,11 +226,11 @@ def train():
|
||||
optimizer_minamo.zero_grad()
|
||||
optimizer_ginka.zero_grad()
|
||||
if train_stage == 1 or train_stage == 2:
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, False)
|
||||
|
||||
loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1)
|
||||
loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2)
|
||||
loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3)
|
||||
loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond)
|
||||
loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
|
||||
loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)
|
||||
|
||||
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
||||
loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
|
||||
@ -241,14 +241,14 @@ def train():
|
||||
loss_ce_total += loss_ce.detach()
|
||||
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, True, False, train_stage == 4)
|
||||
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
|
||||
|
||||
if train_stage == 3:
|
||||
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
|
||||
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1, tag_cond, val_cond)
|
||||
else:
|
||||
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
|
||||
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
|
||||
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
|
||||
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1, tag_cond, val_cond)
|
||||
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond)
|
||||
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)
|
||||
|
||||
if train_stage == 4:
|
||||
loss_head = criterion.generator_input_head_loss(x_in)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user