From 10e6ad63944046596cc952af6525f680569ef487 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 30 Apr 2025 21:35:53 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E8=AE=AD=E7=BB=83=E5=BE=AA=E7=8E=AF=20&?= =?UTF-8?q?=20=E6=80=AA=E7=89=A9=E6=95=B0=E6=8D=AE=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/src/floor.ts | 6 +++--- ginka/train_wgan.py | 42 +++++++++++++++++++++--------------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/data/src/floor.ts b/data/src/floor.ts index f8d1096..4f74d81 100644 --- a/data/src/floor.ts +++ b/data/src/floor.ts @@ -239,11 +239,11 @@ function convert( const attr = (enemy.atk + enemy.def) * enemy.hp; const ad = attr - minAttr; if (ad < delta / 3 || delta === 0) { - res[ny][nx] = 25; - } else if (ad < (delta * 2) / 3) { res[ny][nx] = 26; - } else { + } else if (ad < (delta * 2) / 3) { res[ny][nx] = 27; + } else { + res[ny][nx] = 28; } } } diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 1f7ecc7..10c73c9 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -71,9 +71,9 @@ def parse_arguments(): return args def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - fake1, _ = gen(masked1, 1, False, tag, val) - fake2, _ = gen(masked2, 2, False, tag, val) - fake3, _ = gen(masked3, 3, False, tag, val) + fake1, _ = gen(masked1, 1, tag, val) + fake2, _ = gen(masked2, 2, tag, val) + fake3, _ = gen(masked3, 3, tag, val) if detach: return fake1.detach(), fake2.detach(), fake3.detach() else: @@ -81,13 +81,13 @@ def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tu def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> torch.Tensor: if progress_detach: - fake1, x_in = gen(input.detach(), 1, random, tag, val) - fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, False, tag, val) - fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, False, tag, val) + fake1, x_in = gen(input.detach(), 1, tag, val, random) + fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val) + fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val) else: - fake1, x_in = gen(input, 1, random, tag, val) - fake2, _ = gen(F.softmax(fake1, dim=1), 2, False, tag, val) - fake3, _ = gen(F.softmax(fake2, dim=1), 3, False, tag, val) + fake1, x_in = gen(input, 1, tag, val, random) + fake2, _ = gen(F.softmax(fake1, dim=1), 2, tag, val) + fake3, _ = gen(F.softmax(fake2, dim=1), 3, tag, val) if result_detach: return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach() else: @@ -205,9 +205,9 @@ def train(): elif train_stage == 3 or train_stage == 4: fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) - loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1) - loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2) - loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3) + loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond) + loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond) + loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond) dis_avg = (dis1 + dis2 + dis3) / 3.0 loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0 @@ -226,11 +226,11 @@ def train(): optimizer_minamo.zero_grad() optimizer_ginka.zero_grad() if train_stage == 1 or train_stage == 2: - fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False) + fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, False) - loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1) - loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2) - loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3) + loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond) + loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond) + loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond) loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0 loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3) @@ -241,14 +241,14 @@ def train(): loss_ce_total += loss_ce.detach() elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, x_in = gen_total(ginka, masked1, True, False, train_stage == 4) + fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4) if train_stage == 3: - loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1) + loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1, tag_cond, val_cond) else: - loss_g1 = criterion.generator_loss_total(minamo, 1, fake1) - loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1) - loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2) + loss_g1 = criterion.generator_loss_total(minamo, 1, fake1, tag_cond, val_cond) + loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond) + loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond) if train_stage == 4: loss_head = criterion.generator_input_head_loss(x_in)