mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 添加 js 散度损失
This commit is contained in:
parent
87bada8553
commit
9b5abf177c
21
README.md
21
README.md
@ -9,15 +9,14 @@ GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对
|
||||
对于 HTML5 魔塔,如果你想要贡献数据集,需要对你的魔塔进行手动数据处理,流程如下:
|
||||
|
||||
1. 选择楼层,可以是剧情层、战斗层等,但是需要满足下述条件
|
||||
2. 楼层除边缘外不应出现墙壁堆叠(例如 2\*2,边缘可以有重叠)
|
||||
3. 楼层中不应该有闲置怪,不应该在直线上有无间隔连续 3 个以上的怪物,不应该有无法到达的区域,不宜有过多的入口
|
||||
4. 最外面一层围上一圈墙壁(箭头楼层切换除外)
|
||||
5. 将所有的墙壁换成黄墙(数字 1)
|
||||
6. 将所有的血瓶换成红血瓶(数字 31),所有红宝石换成最基础的红宝石(数字 27),蓝宝石换成最基础的蓝宝石(数字 28),道具全部换为幸运金币(数字 53),剑盾可以当成红蓝宝石看待,删除除此之外的资源
|
||||
7. 所有钥匙换成黄钥匙(数字 21),所有门换成黄门(数字 81)
|
||||
8. 所有箭头换成样板原版箭头(数字 91 至 94),所有上下楼梯换成样板原版楼梯(数字 87 和 88)
|
||||
9. 怪物分为三个强度,弱怪,中怪,强怪,弱怪换为绿头怪(数字 201),中怪换成红头怪(数字 202),强怪换成青头怪(数字 203)
|
||||
10. 在 `project` 文件夹下创建 `ginka-config.json` 文件,双击进入编辑,粘贴如下模板:
|
||||
2. 楼层中不应该有闲置怪,不应该在直线上有无间隔连续 3 个以上的怪物,不应该有无法到达的区域,不宜有过多的入口
|
||||
3. 最外面一层围上一圈墙壁(箭头楼层切换除外)
|
||||
4. 将所有的墙壁换成黄墙(数字 1)
|
||||
5. 将所有的血瓶换成红血瓶(数字 31),所有红宝石换成最基础的红宝石(数字 27),蓝宝石换成最基础的蓝宝石(数字 28),绿宝石换成最基础的绿宝石(数字 29),道具全部换为幸运金币(数字 53),剑盾可以当成红蓝宝石看待,删除除此之外的资源(或者换成允许的资源)
|
||||
6. 所有钥匙换成黄钥匙(数字 21),所有门换成黄门(数字 81)
|
||||
7. 所有箭头换成样板原版箭头(数字 91 至 94),所有上下楼梯换成样板原版楼梯(数字 87 和 88)
|
||||
8. 怪物分为三个强度,弱怪,中怪,强怪,弱怪换为绿头怪(数字 201),中怪换成红头怪(数字 202),强怪换成青头怪(数字 203)
|
||||
9. 在 `project` 文件夹下创建 `ginka-config.json` 文件,双击进入编辑,粘贴如下模板:
|
||||
|
||||
```json
|
||||
{
|
||||
@ -33,5 +32,5 @@ GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对
|
||||
|
||||
其中,`clip` 属性表示你的每张地图的哪一部分会被当成数据集,例如填写 `[0, 0, 13, 13]` 就会让坐标为 `(0, 0)`,长宽为 `(13, 13)` 的矩形内容作为数据集。`special` 属性允许你针对单独的某几层设置不同的裁剪方式,例如设置 `MT11` 为 `[3, 3, 7, 7]` 等,如果没有设置默认使用 `defaults` 的裁剪方式。最好保证每个楼层大小一致,不然我还要手动分类。
|
||||
|
||||
11. 在全塔属性中的楼层列表中去除不在数据集内的楼层
|
||||
12. 将 `project` 文件夹打包发给我即可
|
||||
10. 在全塔属性中的楼层列表中去除不在数据集内的楼层
|
||||
11. 将 `project` 文件夹打包发给我即可
|
||||
|
||||
@ -15,7 +15,8 @@ const numMap: Record<number, number> = {
|
||||
92: 11, // 箭头
|
||||
93: 11, // 箭头
|
||||
94: 11, // 箭头
|
||||
53: 12 // 道具
|
||||
53: 12, // 道具
|
||||
29: 13, // 绿宝石
|
||||
};
|
||||
|
||||
const apeiriaMap: Record<number, number> = {
|
||||
@ -27,7 +28,7 @@ const apeiriaMap: Record<number, number> = {
|
||||
23: 2, // 红钥匙
|
||||
27: 3, // 红宝石
|
||||
28: 4, // 蓝宝石
|
||||
29: 0, // 绿宝石
|
||||
29: 13, // 绿宝石
|
||||
31: 5, // 红血瓶
|
||||
32: 5, // 蓝血瓶
|
||||
33: 5, // 绿血瓶
|
||||
|
||||
@ -13,7 +13,7 @@ export const tileType = new Set(
|
||||
);
|
||||
const branchType = new Set([6, 7, 8, 9]);
|
||||
const entranceType = new Set([10, 11]);
|
||||
const resourceType = new Set([0, 2, 3, 4, 5, 10, 11, 12]);
|
||||
const resourceType = new Set([0, 2, 3, 4, 5, 10, 11, 12, 13]);
|
||||
|
||||
export const directions: [number, number][] = [
|
||||
[-1, 0],
|
||||
|
||||
@ -9,10 +9,10 @@ from typing import List
|
||||
from shared.utils import random_smooth_onehot
|
||||
|
||||
STAGE1_MASK = [0, 1, 10, 11]
|
||||
STAGE1_REMOVE = [2, 3, 4, 5, 6, 7, 8, 9, 12]
|
||||
STAGE1_REMOVE = [2, 3, 4, 5, 6, 7, 8, 9, 12, 13]
|
||||
STAGE2_MASK = [6, 7, 8, 9]
|
||||
STAGE2_REMOVE = [2, 3, 4, 5, 12]
|
||||
STAGE3_MASK = [2, 3, 4, 5, 12]
|
||||
STAGE2_REMOVE = [2, 3, 4, 5, 12, 13]
|
||||
STAGE3_MASK = [2, 3, 4, 5, 12, 13]
|
||||
STAGE3_REMOVE = []
|
||||
|
||||
def load_data(path: str):
|
||||
@ -65,7 +65,7 @@ def apply_curriculum_mask(
|
||||
|
||||
# Step 2: 对指定类别随机遮挡
|
||||
for cls in mask_classes:
|
||||
cls_mask = masked_maps[:, cls] > 0 # 目标类别的像素布尔掩码 [H, W]
|
||||
cls_mask = masked_maps[cls] > 0 # 目标类别的像素布尔掩码 [H, W]
|
||||
indices = cls_mask.nonzero(as_tuple=False) # 所有该类像素坐标
|
||||
num_mask = int(len(indices) * mask_ratio)
|
||||
if num_mask > 0:
|
||||
@ -139,7 +139,17 @@ class GinkaWGANDataset(Dataset):
|
||||
return self.handle_stage3(target)
|
||||
|
||||
elif self.train_stage == 4:
|
||||
return self.handle_stage4(target)
|
||||
self.mask_ratio1 = self.mask_ratio2 = self.mask_ratio3 = random.uniform(0, 0.9)
|
||||
self.random_ratio = 0.2
|
||||
mode = random.choices([1, 2, 3, 4], weights=[0.2, 0.2, 0.2, 0.4])
|
||||
if mode == 1:
|
||||
return self.handle_stage1(target)
|
||||
elif mode == 2:
|
||||
return self.handle_stage2(target)
|
||||
elif mode == 3:
|
||||
return self.handle_stage3(target)
|
||||
else:
|
||||
return self.handle_stage4(target)
|
||||
|
||||
raise RuntimeError(f"Invalid train stage: {self.train_stage}")
|
||||
|
||||
@ -11,13 +11,13 @@ from shared.similarity.topo import overall_similarity, build_topological_graph
|
||||
from shared.similarity.vision import calculate_visual_similarity
|
||||
|
||||
CLASS_NUM = 32
|
||||
ILLEGAL_MAX_NUM = 12
|
||||
ILLEGAL_MAX_NUM = 13
|
||||
|
||||
STAGE_ALLOWED = [
|
||||
[],
|
||||
[0, 1, 10, 11],
|
||||
[6, 7, 8, 9,],
|
||||
[2, 3, 4, 5, 12]
|
||||
[6, 7, 8, 9],
|
||||
[2, 3, 4, 5, 12, 13]
|
||||
]
|
||||
|
||||
def get_not_allowed(classes: list[int], include_illegal=False):
|
||||
@ -302,11 +302,14 @@ def js_divergence(p, q, eps=1e-8):
|
||||
# log_softmax 以供 kl_div 使用
|
||||
log_p = torch.log(p + eps)
|
||||
log_q = torch.log(q + eps)
|
||||
log_m = torch.log(m + eps)
|
||||
|
||||
nn.KLDivLoss
|
||||
|
||||
kl_pm = F.kl_div(log_p, m, reduction='batchmean', log_target=False) # KL(p || m)
|
||||
kl_qm = F.kl_div(log_q, m, reduction='batchmean', log_target=False) # KL(q || m)
|
||||
kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m)
|
||||
kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m)
|
||||
|
||||
return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0)
|
||||
return 0.5 * (kl_pm + kl_qm)
|
||||
|
||||
def immutable_penalty_loss(
|
||||
pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
|
||||
@ -332,8 +335,8 @@ def immutable_penalty_loss(
|
||||
return penalty
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.4, 10, 0.2, 0.2]):
|
||||
# weight: 判别器损失,L1 损失,不可修改类型损失
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.02]):
|
||||
# weight: 判别器损失,L1 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
|
||||
@ -399,7 +402,7 @@ class WGANGinkaLoss:
|
||||
|
||||
return vis_sim, topo_sim
|
||||
|
||||
def generator_loss(self, critic, stage, mask_ratio, real, fake, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
""" 生成器损失函数 """
|
||||
fake_graph = batch_convert_soft_map_to_graph(fake)
|
||||
|
||||
@ -409,11 +412,14 @@ class WGANGinkaLoss:
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3]
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
@ -421,7 +427,7 @@ class WGANGinkaLoss:
|
||||
entrance_loss = entrance_constraint_loss(fake)
|
||||
losses.append(entrance_loss * self.weight[4])
|
||||
|
||||
# print(losses[2].item())
|
||||
# print(-js_divergence(fake_a, fake_b).item())
|
||||
|
||||
return sum(losses), minamo_loss, ce_loss, immutable_loss
|
||||
|
||||
@ -433,10 +439,13 @@ class WGANGinkaLoss:
|
||||
immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3]
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
|
||||
@ -70,9 +70,9 @@ def train():
|
||||
# 1 代表课程学习阶段,2 代表课程学习后,逐渐转为联合学习的阶段
|
||||
# 3 代表课程学习后的联合遮挡学习阶段,4 代表最后随机输入的联合学习阶段
|
||||
train_stage = 1
|
||||
mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
||||
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
||||
random_ratio = 0
|
||||
stage3_epoch = 0 # 第三阶段 epoch 数,100 轮后进入第四阶段
|
||||
stage3_epoch = 0 # 第三阶段 epoch 数,若干轮后进入第四阶段
|
||||
|
||||
ginka = GinkaModel()
|
||||
minamo = MinamoScoreModule()
|
||||
@ -216,9 +216,9 @@ def train():
|
||||
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
||||
tqdm.write(
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||
f"Epoch: {epoch + 1} | W: {avg_dis:.8f} | " +
|
||||
f"G: {avg_loss_ginka:.8f} | D: {avg_loss_minamo:.8f} | " +
|
||||
f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}"
|
||||
f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
||||
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}"
|
||||
)
|
||||
|
||||
if avg_loss_ce < 0.1:
|
||||
@ -226,23 +226,24 @@ def train():
|
||||
else:
|
||||
low_loss_epochs = 0
|
||||
|
||||
if low_loss_epochs >= 5 and train_stage == 2:
|
||||
if low_loss_epochs >= 3 and train_stage == 2:
|
||||
if random_ratio >= 0.5:
|
||||
train_stage = 3
|
||||
random_ratio += 0.1
|
||||
random_ratio += 0.2
|
||||
random_ratio = min(random_ratio, 0.5)
|
||||
low_loss_epochs = 0
|
||||
|
||||
if low_loss_epochs >= 5 and train_stage == 1:
|
||||
if low_loss_epochs >= 3 and train_stage == 1:
|
||||
if mask_ratio >= 0.9:
|
||||
train_stage = 2
|
||||
mask_ratio += 0.1
|
||||
mask_ratio += 0.2
|
||||
mask_ratio = min(mask_ratio, 0.9)
|
||||
low_loss_epochs = 0
|
||||
|
||||
if train_stage == 3:
|
||||
stage3_epoch += 1
|
||||
if stage3_epoch >= 100:
|
||||
# 十轮足够了
|
||||
if stage3_epoch >= 10:
|
||||
train_stage = 4
|
||||
stage3_epoch = 0
|
||||
|
||||
@ -250,8 +251,8 @@ def train():
|
||||
# 第二阶段后 L1 损失不再应该生效
|
||||
mask_ratio = 1.0
|
||||
|
||||
dataset.train_stage = 2
|
||||
dataset_val.train_stage = 2
|
||||
dataset.train_stage = train_stage
|
||||
dataset_val.train_stage = train_stage
|
||||
dataset.random_ratio = random_ratio
|
||||
dataset_val.random_ratio = random_ratio
|
||||
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
||||
@ -292,19 +293,23 @@ def train():
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
|
||||
if train_stage == 1:
|
||||
if train_stage == 1 or train_stage == 2:
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
||||
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
|
||||
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
|
||||
fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
|
||||
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
|
||||
|
||||
for i in range(fake1.shape[0]):
|
||||
for key, one in enumerate([fake1, fake2, fake3]):
|
||||
map_matrix = one[i]
|
||||
image = matrix_to_image_cv(map_matrix, tile_dict)
|
||||
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
|
||||
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
|
||||
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
|
||||
fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
|
||||
|
||||
for i in range(fake1.shape[0]):
|
||||
for key, one in enumerate([fake1, fake2, fake3]):
|
||||
map_matrix = one[i]
|
||||
image = matrix_to_image_cv(map_matrix, tile_dict)
|
||||
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
|
||||
|
||||
idx += 1
|
||||
idx += 1
|
||||
|
||||
print("Train ended.")
|
||||
torch.save({
|
||||
|
||||
@ -44,7 +44,7 @@ class GinkaTopologicalGraphs:
|
||||
TILE_TYPE = set(range(13))
|
||||
BRANCH_TYPE = {6, 7, 8, 9}
|
||||
ENTRANCE_TYPE = {10, 11}
|
||||
RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12}
|
||||
RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12, 13}
|
||||
|
||||
directions: List[Tuple[int, int]] = [
|
||||
(-1, 0), (1, 0), (0, -1), (0, 1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user