fix: 训练循环 & 怪物数据处理

This commit is contained in:
unanmed 2025-04-30 21:35:53 +08:00
parent 1cded87530
commit 10e6ad6394
2 changed files with 24 additions and 24 deletions

View File

@ -239,11 +239,11 @@ function convert(
const attr = (enemy.atk + enemy.def) * enemy.hp;
const ad = attr - minAttr;
if (ad < delta / 3 || delta === 0) {
res[ny][nx] = 25;
} else if (ad < (delta * 2) / 3) {
res[ny][nx] = 26;
} else {
} else if (ad < (delta * 2) / 3) {
res[ny][nx] = 27;
} else {
res[ny][nx] = 28;
}
}
}

View File

@ -71,9 +71,9 @@ def parse_arguments():
return args
def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
fake1, _ = gen(masked1, 1, False, tag, val)
fake2, _ = gen(masked2, 2, False, tag, val)
fake3, _ = gen(masked3, 3, False, tag, val)
fake1, _ = gen(masked1, 1, tag, val)
fake2, _ = gen(masked2, 2, tag, val)
fake3, _ = gen(masked3, 3, tag, val)
if detach:
return fake1.detach(), fake2.detach(), fake3.detach()
else:
@ -81,13 +81,13 @@ def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tu
def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
if progress_detach:
fake1, x_in = gen(input.detach(), 1, random, tag, val)
fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, False, tag, val)
fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, False, tag, val)
fake1, x_in = gen(input.detach(), 1, tag, val, random)
fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val)
fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val)
else:
fake1, x_in = gen(input, 1, random, tag, val)
fake2, _ = gen(F.softmax(fake1, dim=1), 2, False, tag, val)
fake3, _ = gen(F.softmax(fake2, dim=1), 3, False, tag, val)
fake1, x_in = gen(input, 1, tag, val, random)
fake2, _ = gen(F.softmax(fake1, dim=1), 2, tag, val)
fake3, _ = gen(F.softmax(fake2, dim=1), 3, tag, val)
if result_detach:
return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
else:
@ -205,9 +205,9 @@ def train():
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond)
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond)
dis_avg = (dis1 + dis2 + dis3) / 3.0
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
@ -226,11 +226,11 @@ def train():
optimizer_minamo.zero_grad()
optimizer_ginka.zero_grad()
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, False)
loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1)
loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2)
loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3)
loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond)
loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
@ -241,14 +241,14 @@ def train():
loss_ce_total += loss_ce.detach()
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, True, False, train_stage == 4)
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
if train_stage == 3:
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1, tag_cond, val_cond)
else:
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1, tag_cond, val_cond)
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond)
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)
if train_stage == 4:
loss_head = criterion.generator_input_head_loss(x_in)