feat: 添加 js 散度损失

This commit is contained in:
unanmed 2025-04-16 18:39:17 +08:00
parent 87bada8553
commit 9b5abf177c
7 changed files with 78 additions and 54 deletions

View File

@ -9,15 +9,14 @@ GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对
对于 HTML5 魔塔,如果你想要贡献数据集,需要对你的魔塔进行手动数据处理,流程如下:
1. 选择楼层,可以是剧情层、战斗层等,但是需要满足下述条件
2. 楼层除边缘外不应出现墙壁堆叠(例如 2\*2),边缘可以有重叠
3. 楼层中不应该有闲置怪,不应该在直线上有无间隔连续 3 个以上的怪物,不应该有无法到达的区域,不宜有过多的入口
4. 最外面一层围上一圈墙壁(箭头楼层切换除外)
5. 将所有的墙壁换成黄墙(数字 1)
6. 将所有的血瓶换成红血瓶(数字 31),所有红宝石换成最基础的红宝石(数字 27),蓝宝石换成最基础的蓝宝石(数字 28),道具全部换为幸运金币(数字 53),剑盾可以当成红蓝宝石看待,删除除此之外的资源
7. 所有钥匙换成黄钥匙(数字 21),所有门换成黄门(数字 81)
8. 所有箭头换成样板原版箭头(数字 91 至 94),所有上下楼梯换成样板原版楼梯(数字 87 和 88)
9. 怪物分为三个强度,弱怪,中怪,强怪,弱怪换为绿头怪(数字 201),中怪换成红头怪(数字 202),强怪换成青头怪(数字 203)
10. 在 `project` 文件夹下创建 `ginka-config.json` 文件,双击进入编辑,粘贴如下模板:
2. 楼层中不应该有闲置怪,不应该在直线上有无间隔连续 3 个以上的怪物,不应该有无法到达的区域,不宜有过多的入口
3. 最外面一层围上一圈墙壁(箭头楼层切换除外)
4. 将所有的墙壁换成黄墙(数字 1)
5. 将所有的血瓶换成红血瓶(数字 31),所有红宝石换成最基础的红宝石(数字 27),蓝宝石换成最基础的蓝宝石(数字 28),绿宝石换成最基础的绿宝石(数字 29),道具全部换为幸运金币(数字 53),剑盾可以当成红蓝宝石看待,删除除此之外的资源或者换成允许的资源
6. 所有钥匙换成黄钥匙(数字 21),所有门换成黄门(数字 81)
7. 所有箭头换成样板原版箭头(数字 91 至 94),所有上下楼梯换成样板原版楼梯(数字 87 和 88)
8. 怪物分为三个强度,弱怪,中怪,强怪,弱怪换为绿头怪(数字 201),中怪换成红头怪(数字 202),强怪换成青头怪(数字 203)
9. 在 `project` 文件夹下创建 `ginka-config.json` 文件,双击进入编辑,粘贴如下模板:
```json
{
@ -33,5 +32,5 @@ GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对
其中,`clip` 属性表示你的每张地图的那一部分会被当成数据集,例如填写 `[0, 0, 13, 13]` 就会让坐标为 `(0, 0)`,长宽为 `(13, 13)` 的矩形内容作为数据集。`special` 属性允许你针对单独的某几层设置不同的裁剪方式,例如设置 `MT11` 为 `[3, 3, 7, 7]` 等,如果没有设置默认使用 `defaults` 的裁剪方式。最好保证每个楼层大小一致,不然我还要手动分类。
11. 在全塔属性中的楼层列表中去除不在数据集内的楼层
12. 将 `project` 文件夹打包发给我即可
10. 在全塔属性中的楼层列表中去除不在数据集内的楼层
11. 将 `project` 文件夹打包发给我即可

View File

@ -15,7 +15,8 @@ const numMap: Record<number, number> = {
92: 11, // 箭头
93: 11, // 箭头
94: 11, // 箭头
53: 12 // 道具
53: 12, // 道具
29: 13, // 绿宝石
};
const apeiriaMap: Record<number, number> = {
@ -27,7 +28,7 @@ const apeiriaMap: Record<number, number> = {
23: 2, // 红钥匙
27: 3, // 红宝石
28: 4, // 蓝宝石
29: 0, // 绿宝石
29: 13, // 绿宝石
31: 5, // 红血瓶
32: 5, // 蓝血瓶
33: 5, // 绿血瓶

View File

@ -13,7 +13,7 @@ export const tileType = new Set(
);
const branchType = new Set([6, 7, 8, 9]);
const entranceType = new Set([10, 11]);
const resourceType = new Set([0, 2, 3, 4, 5, 10, 11, 12]);
const resourceType = new Set([0, 2, 3, 4, 5, 10, 11, 12, 13]);
export const directions: [number, number][] = [
[-1, 0],

View File

@ -9,10 +9,10 @@ from typing import List
from shared.utils import random_smooth_onehot
STAGE1_MASK = [0, 1, 10, 11]
STAGE1_REMOVE = [2, 3, 4, 5, 6, 7, 8, 9, 12]
STAGE1_REMOVE = [2, 3, 4, 5, 6, 7, 8, 9, 12, 13]
STAGE2_MASK = [6, 7, 8, 9]
STAGE2_REMOVE = [2, 3, 4, 5, 12]
STAGE3_MASK = [2, 3, 4, 5, 12]
STAGE2_REMOVE = [2, 3, 4, 5, 12, 13]
STAGE3_MASK = [2, 3, 4, 5, 12, 13]
STAGE3_REMOVE = []
def load_data(path: str):
@ -65,7 +65,7 @@ def apply_curriculum_mask(
# Step 2: 对指定类别随机遮挡
for cls in mask_classes:
cls_mask = masked_maps[:, cls] > 0 # 目标类别的像素布尔掩码 [H, W]
cls_mask = masked_maps[cls] > 0 # 目标类别的像素布尔掩码 [H, W]
indices = cls_mask.nonzero(as_tuple=False) # 所有该类像素坐标
num_mask = int(len(indices) * mask_ratio)
if num_mask > 0:
@ -139,7 +139,17 @@ class GinkaWGANDataset(Dataset):
return self.handle_stage3(target)
elif self.train_stage == 4:
return self.handle_stage4(target)
self.mask_ratio1 = self.mask_ratio2 = self.mask_ratio3 = random.uniform(0, 0.9)
self.random_ratio = 0.2
mode = random.choices([1, 2, 3, 4], weights=[0.2, 0.2, 0.2, 0.4])
if mode == 1:
return self.handle_stage1(target)
elif mode == 2:
return self.handle_stage2(target)
elif mode == 3:
return self.handle_stage3(target)
else:
return self.handle_stage4(target)
raise RuntimeError(f"Invalid train stage: {self.train_stage}")

View File

@ -11,13 +11,13 @@ from shared.similarity.topo import overall_similarity, build_topological_graph
from shared.similarity.vision import calculate_visual_similarity
CLASS_NUM = 32
ILLEGAL_MAX_NUM = 12
ILLEGAL_MAX_NUM = 13
STAGE_ALLOWED = [
[],
[0, 1, 10, 11],
[6, 7, 8, 9,],
[2, 3, 4, 5, 12]
[6, 7, 8, 9],
[2, 3, 4, 5, 12, 13]
]
def get_not_allowed(classes: list[int], include_illegal=False):
@ -302,11 +302,14 @@ def js_divergence(p, q, eps=1e-8):
# log_softmax 以供 kl_div 使用
log_p = torch.log(p + eps)
log_q = torch.log(q + eps)
log_m = torch.log(m + eps)
nn.KLDivLoss
kl_pm = F.kl_div(log_p, m, reduction='batchmean', log_target=False) # KL(p || m)
kl_qm = F.kl_div(log_q, m, reduction='batchmean', log_target=False) # KL(q || m)
kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m)
kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m)
return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0)
return 0.5 * (kl_pm + kl_qm)
def immutable_penalty_loss(
pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
@ -332,8 +335,8 @@ def immutable_penalty_loss(
return penalty
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.4, 10, 0.2, 0.2]):
# weight: 判别器损失L1 损失,不可修改类型损失
def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.02]):
# weight: 判别器损失L1 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight
@ -399,7 +402,7 @@ class WGANGinkaLoss:
return vis_sim, topo_sim
def generator_loss(self, critic, stage, mask_ratio, real, fake, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" 生成器损失函数 """
fake_graph = batch_convert_soft_map_to_graph(fake)
@ -409,11 +412,14 @@ class WGANGinkaLoss:
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
fake_a, fake_b = fake.chunk(2, dim=0)
losses = [
minamo_loss * self.weight[0],
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
immutable_loss * self.weight[2],
constraint_loss * self.weight[3]
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b) * self.weight[5],
]
if stage == 1:
@ -421,7 +427,7 @@ class WGANGinkaLoss:
entrance_loss = entrance_constraint_loss(fake)
losses.append(entrance_loss * self.weight[4])
# print(losses[2].item())
# print(-js_divergence(fake_a, fake_b).item())
return sum(losses), minamo_loss, ce_loss, immutable_loss
@ -433,10 +439,13 @@ class WGANGinkaLoss:
immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
fake_a, fake_b = fake.chunk(2, dim=0)
losses = [
minamo_loss * self.weight[0],
immutable_loss * self.weight[2],
constraint_loss * self.weight[3]
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b) * self.weight[5],
]
if stage == 1:

View File

@ -70,9 +70,9 @@ def train():
# 1 代表课程学习阶段2 代表课程学习后,逐渐转为联合学习的阶段
# 3 代表课程学习后的联合遮挡学习阶段4 代表最后随机输入的联合学习阶段
train_stage = 1
mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
random_ratio = 0
stage3_epoch = 0 # 第三阶段 epoch 数,100 轮后进入第四阶段
stage3_epoch = 0 # 第三阶段 epoch 数,若干轮后进入第四阶段
ginka = GinkaModel()
minamo = MinamoScoreModule()
@ -216,9 +216,9 @@ def train():
avg_dis = dis_total.item() / len(dataloader) / c_steps
tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
f"Epoch: {epoch + 1} | W: {avg_dis:.8f} | " +
f"G: {avg_loss_ginka:.8f} | D: {avg_loss_minamo:.8f} | " +
f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}"
f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}"
)
if avg_loss_ce < 0.1:
@ -226,23 +226,24 @@ def train():
else:
low_loss_epochs = 0
if low_loss_epochs >= 5 and train_stage == 2:
if low_loss_epochs >= 3 and train_stage == 2:
if random_ratio >= 0.5:
train_stage = 3
random_ratio += 0.1
random_ratio += 0.2
random_ratio = min(random_ratio, 0.5)
low_loss_epochs = 0
if low_loss_epochs >= 5 and train_stage == 1:
if low_loss_epochs >= 3 and train_stage == 1:
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.1
mask_ratio += 0.2
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
if train_stage == 3:
stage3_epoch += 1
if stage3_epoch >= 100:
# 十轮足够了
if stage3_epoch >= 10:
train_stage = 4
stage3_epoch = 0
@ -250,8 +251,8 @@ def train():
# 第二阶段后 L1 损失不再应该生效
mask_ratio = 1.0
dataset.train_stage = 2
dataset_val.train_stage = 2
dataset.train_stage = train_stage
dataset_val.train_stage = train_stage
dataset.random_ratio = random_ratio
dataset_val.random_ratio = random_ratio
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
@ -292,19 +293,23 @@ def train():
with torch.no_grad():
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
if train_stage == 1:
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
for i in range(fake1.shape[0]):
for key, one in enumerate([fake1, fake2, fake3]):
map_matrix = one[i]
image = matrix_to_image_cv(map_matrix, tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
for i in range(fake1.shape[0]):
for key, one in enumerate([fake1, fake2, fake3]):
map_matrix = one[i]
image = matrix_to_image_cv(map_matrix, tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
idx += 1
idx += 1
print("Train ended.")
torch.save({

View File

@ -44,7 +44,7 @@ class GinkaTopologicalGraphs:
TILE_TYPE = set(range(13))
BRANCH_TYPE = {6, 7, 8, 9}
ENTRANCE_TYPE = {10, 11}
RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12}
RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12, 13}
directions: List[Tuple[int, int]] = [
(-1, 0), (1, 0), (0, -1), (0, 1)