diff --git a/README.md b/README.md index 5038e0b..7b1d315 100644 --- a/README.md +++ b/README.md @@ -9,15 +9,14 @@ GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对 对于 HTML5 魔塔,如果你想要贡献数据集,需要对你的魔塔进行手动数据处理,流程如下: 1. 选择楼层,可以是剧情层、战斗层等,但是需要满足下述条件 -2. 楼层除边缘外不应出现墙壁堆叠(例如 2\*2,边缘可以有重叠) -3. 楼层中不应该有闲置怪,不应该在直线上有无间隔连续 3 个以上的怪物,不应该有无法到达的区域,不宜有过多的入口 -4. 最外面一层围上一圈墙壁(箭头楼层切换除外) -5. 将所有的墙壁换成黄墙(数字 1) -6. 将所有的血瓶换成红血瓶(数字 31),所有红宝石换成最基础的红宝石(数字 27),蓝宝石换成最基础的蓝宝石(数字 28),道具全部换为幸运金币(数字 53),剑盾可以当成红蓝宝石看待,删除除此之外的资源 -7. 所有钥匙换成黄钥匙(数字 21),所有门换成黄门(数字 81) -8. 所有箭头换成样板原版箭头(数字 91 至 94),所有上下楼梯换成样板原版楼梯(数字 87 和 88) -9. 怪物分为三个强度,弱怪,中怪,强怪,弱怪换为绿头怪(数字 201),中怪换成红头怪(数字 202),强怪换成青头怪(数字 203) -10. 在 `project` 文件夹下创建 `ginka-config.json` 文件,双击进入编辑,粘贴如下模板: +2. 楼层中不应该有闲置怪,不应该在直线上有无间隔连续 3 个以上的怪物,不应该有无法到达的区域,不宜有过多的入口 +3. 最外面一层围上一圈墙壁(箭头楼层切换除外) +4. 将所有的墙壁换成黄墙(数字 1) +5. 将所有的血瓶换成红血瓶(数字 31),所有红宝石换成最基础的红宝石(数字 27),蓝宝石换成最基础的蓝宝石(数字 28),绿宝石换成最基础的绿宝石(数字 29),道具全部换为幸运金币(数字 53),剑盾可以当成红蓝宝石看待,删除除此之外的资源(或者换成允许的资源) +6. 所有钥匙换成黄钥匙(数字 21),所有门换成黄门(数字 81) +7. 所有箭头换成样板原版箭头(数字 91 至 94),所有上下楼梯换成样板原版楼梯(数字 87 和 88) +8. 怪物分为三个强度,弱怪,中怪,强怪,弱怪换为绿头怪(数字 201),中怪换成红头怪(数字 202),强怪换成青头怪(数字 203) +9. 在 `project` 文件夹下创建 `ginka-config.json` 文件,双击进入编辑,粘贴如下模板: ```json { @@ -33,5 +32,5 @@ GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对 其中,`clip` 属性表示你的每张地图的那一部分会被当成数据集,例如填写 `[0, 0, 13, 13]` 就会让坐标为 `(0, 0)`,长宽为 `(13, 13)` 的矩形内容作为数据集。`special` 属性允许你针对单独的某几层设置不同的裁剪方式,例如设置 `MT11` 为 `[3, 3, 7, 7]` 等,如果没有设置默认使用 `defaults` 的裁剪方式。最好保证每个楼层大小一致,不然我还要手动分类。 -11. 在全塔属性中的楼层列表中去除不在数据集内的楼层 -12. 将 `project` 文件夹打包发给我即可 +10. 在全塔属性中的楼层列表中去除不在数据集内的楼层 +11. 将 `project` 文件夹打包发给我即可 diff --git a/data/src/floor.ts b/data/src/floor.ts index 8fcec9b..5c3bba8 100644 --- a/data/src/floor.ts +++ b/data/src/floor.ts @@ -15,7 +15,8 @@ const numMap: Record = { 92: 11, // 箭头 93: 11, // 箭头 94: 11, // 箭头 - 53: 12 // 道具 + 53: 12, // 道具 + 29: 13, // 绿宝石 }; const apeiriaMap: Record = { @@ -27,7 +28,7 @@ const apeiriaMap: Record = { 23: 2, // 红钥匙 27: 3, // 红宝石 28: 4, // 蓝宝石 - 29: 0, // 绿宝石 + 29: 13, // 绿宝石 31: 5, // 红血瓶 32: 5, // 蓝血瓶 33: 5, // 绿血瓶 diff --git a/data/src/topology/graph.ts b/data/src/topology/graph.ts index 5345f08..db62d62 100644 --- a/data/src/topology/graph.ts +++ b/data/src/topology/graph.ts @@ -13,7 +13,7 @@ export const tileType = new Set( ); const branchType = new Set([6, 7, 8, 9]); const entranceType = new Set([10, 11]); -const resourceType = new Set([0, 2, 3, 4, 5, 10, 11, 12]); +const resourceType = new Set([0, 2, 3, 4, 5, 10, 11, 12, 13]); export const directions: [number, number][] = [ [-1, 0], diff --git a/ginka/dataset.py b/ginka/dataset.py index 866b2c7..21c0906 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -9,10 +9,10 @@ from typing import List from shared.utils import random_smooth_onehot STAGE1_MASK = [0, 1, 10, 11] -STAGE1_REMOVE = [2, 3, 4, 5, 6, 7, 8, 9, 12] +STAGE1_REMOVE = [2, 3, 4, 5, 6, 7, 8, 9, 12, 13] STAGE2_MASK = [6, 7, 8, 9] -STAGE2_REMOVE = [2, 3, 4, 5, 12] -STAGE3_MASK = [2, 3, 4, 5, 12] +STAGE2_REMOVE = [2, 3, 4, 5, 12, 13] +STAGE3_MASK = [2, 3, 4, 5, 12, 13] STAGE3_REMOVE = [] def load_data(path: str): @@ -65,7 +65,7 @@ def apply_curriculum_mask( # Step 2: 对指定类别随机遮挡 for cls in mask_classes: - cls_mask = masked_maps[:, cls] > 0 # 目标类别的像素布尔掩码 [H, W] + cls_mask = masked_maps[cls] > 0 # 目标类别的像素布尔掩码 [H, W] indices = cls_mask.nonzero(as_tuple=False) # 所有该类像素坐标 num_mask = int(len(indices) * mask_ratio) if num_mask > 0: @@ -139,7 +139,17 @@ class GinkaWGANDataset(Dataset): return self.handle_stage3(target) elif self.train_stage == 4: - return self.handle_stage4(target) + self.mask_ratio1 = self.mask_ratio2 = self.mask_ratio3 = random.uniform(0, 0.9) + self.random_ratio = 0.2 + mode = random.choices([1, 2, 3, 4], weights=[0.2, 0.2, 0.2, 0.4]) + if mode == 1: + return self.handle_stage1(target) + elif mode == 2: + return self.handle_stage2(target) + elif mode == 3: + return self.handle_stage3(target) + else: + return self.handle_stage4(target) raise RuntimeError(f"Invalid train stage: {self.train_stage}") \ No newline at end of file diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 6ce7c1f..a154f9c 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -11,13 +11,13 @@ from shared.similarity.topo import overall_similarity, build_topological_graph from shared.similarity.vision import calculate_visual_similarity CLASS_NUM = 32 -ILLEGAL_MAX_NUM = 12 +ILLEGAL_MAX_NUM = 13 STAGE_ALLOWED = [ [], [0, 1, 10, 11], - [6, 7, 8, 9,], - [2, 3, 4, 5, 12] + [6, 7, 8, 9], + [2, 3, 4, 5, 12, 13] ] def get_not_allowed(classes: list[int], include_illegal=False): @@ -302,11 +302,14 @@ def js_divergence(p, q, eps=1e-8): # log_softmax 以供 kl_div 使用 log_p = torch.log(p + eps) log_q = torch.log(q + eps) + log_m = torch.log(m + eps) + + nn.KLDivLoss - kl_pm = F.kl_div(log_p, m, reduction='batchmean', log_target=False) # KL(p || m) - kl_qm = F.kl_div(log_q, m, reduction='batchmean', log_target=False) # KL(q || m) + kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m) + kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m) - return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0) + return 0.5 * (kl_pm + kl_qm) def immutable_penalty_loss( pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int] @@ -332,8 +335,8 @@ def immutable_penalty_loss( return penalty class WGANGinkaLoss: - def __init__(self, lambda_gp=100, weight=[1, 0.4, 10, 0.2, 0.2]): - # weight: 判别器损失,L1 损失,不可修改类型损失 + def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.02]): + # weight: 判别器损失,L1 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失 self.lambda_gp = lambda_gp # 梯度惩罚系数 self.weight = weight @@ -399,7 +402,7 @@ class WGANGinkaLoss: return vis_sim, topo_sim - def generator_loss(self, critic, stage, mask_ratio, real, fake, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ 生成器损失函数 """ fake_graph = batch_convert_soft_map_to_graph(fake) @@ -409,11 +412,14 @@ class WGANGinkaLoss: immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake) + fake_a, fake_b = fake.chunk(2, dim=0) + losses = [ minamo_loss * self.weight[0], ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小 immutable_loss * self.weight[2], - constraint_loss * self.weight[3] + constraint_loss * self.weight[3], + -js_divergence(fake_a, fake_b) * self.weight[5], ] if stage == 1: @@ -421,7 +427,7 @@ class WGANGinkaLoss: entrance_loss = entrance_constraint_loss(fake) losses.append(entrance_loss * self.weight[4]) - # print(losses[2].item()) + # print(-js_divergence(fake_a, fake_b).item()) return sum(losses), minamo_loss, ce_loss, immutable_loss @@ -433,10 +439,13 @@ class WGANGinkaLoss: immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage]) constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake) + fake_a, fake_b = fake.chunk(2, dim=0) + losses = [ minamo_loss * self.weight[0], immutable_loss * self.weight[2], - constraint_loss * self.weight[3] + constraint_loss * self.weight[3], + -js_divergence(fake_a, fake_b) * self.weight[5], ] if stage == 1: diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 983566b..8f0aebf 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -70,9 +70,9 @@ def train(): # 1 代表课程学习阶段,2 代表课程学习后,逐渐转为联合学习的阶段 # 3 代表课程学习后的联合遮挡学习阶段,4 代表最后随机输入的联合学习阶段 train_stage = 1 - mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练 + mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练 random_ratio = 0 - stage3_epoch = 0 # 第三阶段 epoch 数,100 轮后进入第四阶段 + stage3_epoch = 0 # 第三阶段 epoch 数,若干轮后进入第四阶段 ginka = GinkaModel() minamo = MinamoScoreModule() @@ -216,9 +216,9 @@ def train(): avg_dis = dis_total.item() / len(dataloader) / c_steps tqdm.write( f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + - f"Epoch: {epoch + 1} | W: {avg_dis:.8f} | " + - f"G: {avg_loss_ginka:.8f} | D: {avg_loss_minamo:.8f} | " + - f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}" + f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " + + f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " + + f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}" ) if avg_loss_ce < 0.1: @@ -226,23 +226,24 @@ def train(): else: low_loss_epochs = 0 - if low_loss_epochs >= 5 and train_stage == 2: + if low_loss_epochs >= 3 and train_stage == 2: if random_ratio >= 0.5: train_stage = 3 - random_ratio += 0.1 + random_ratio += 0.2 random_ratio = min(random_ratio, 0.5) low_loss_epochs = 0 - if low_loss_epochs >= 5 and train_stage == 1: + if low_loss_epochs >= 3 and train_stage == 1: if mask_ratio >= 0.9: train_stage = 2 - mask_ratio += 0.1 + mask_ratio += 0.2 mask_ratio = min(mask_ratio, 0.9) low_loss_epochs = 0 if train_stage == 3: stage3_epoch += 1 - if stage3_epoch >= 100: + # 十轮足够了 + if stage3_epoch >= 10: train_stage = 4 stage3_epoch = 0 @@ -250,8 +251,8 @@ def train(): # 第二阶段后 L1 损失不再应该生效 mask_ratio = 1.0 - dataset.train_stage = 2 - dataset_val.train_stage = 2 + dataset.train_stage = train_stage + dataset_val.train_stage = train_stage dataset.random_ratio = random_ratio dataset_val.random_ratio = random_ratio dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio @@ -292,19 +293,23 @@ def train(): with torch.no_grad(): for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch] - if train_stage == 1: + if train_stage == 1 or train_stage == 2: fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True) - fake1 = torch.argmax(fake1, dim=1).cpu().numpy() - fake2 = torch.argmax(fake2, dim=1).cpu().numpy() - fake3 = torch.argmax(fake3, dim=1).cpu().numpy() + + elif train_stage == 3 or train_stage == 4: + fake1, fake2, fake3 = gen_total(ginka, masked1, True, True) - for i in range(fake1.shape[0]): - for key, one in enumerate([fake1, fake2, fake3]): - map_matrix = one[i] - image = matrix_to_image_cv(map_matrix, tile_dict) - cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image) + fake1 = torch.argmax(fake1, dim=1).cpu().numpy() + fake2 = torch.argmax(fake2, dim=1).cpu().numpy() + fake3 = torch.argmax(fake3, dim=1).cpu().numpy() + + for i in range(fake1.shape[0]): + for key, one in enumerate([fake1, fake2, fake3]): + map_matrix = one[i] + image = matrix_to_image_cv(map_matrix, tile_dict) + cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image) - idx += 1 + idx += 1 print("Train ended.") torch.save({ diff --git a/shared/similarity/topo.py b/shared/similarity/topo.py index 3fe80ce..03e97ef 100644 --- a/shared/similarity/topo.py +++ b/shared/similarity/topo.py @@ -44,7 +44,7 @@ class GinkaTopologicalGraphs: TILE_TYPE = set(range(13)) BRANCH_TYPE = {6, 7, 8, 9} ENTRANCE_TYPE = {10, 11} -RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12} +RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12, 13} directions: List[Tuple[int, int]] = [ (-1, 0), (1, 0), (0, -1), (0, 1)