feat: 新的数据处理方式

This commit is contained in:
unanmed 2025-04-28 22:51:53 +08:00
parent f8d6160e5f
commit afde7be592
8 changed files with 8830 additions and 58 deletions

View File

@ -27,11 +27,12 @@ GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对
"redGem": [27],
"blueGem": [28],
"greenGem": [29],
"item": [47, 49, 50, 53],
"yellowGem": [30],
"item": [47, 49, 50, 51, 52, 53],
"potion": [31, 32, 33, 34],
"key": [21, 22, 23],
"door": [81, 82, 83, 85],
"wall": [1]
"wall": [1, 17]
},
"data": {}
}

8707
data/ginka-train.json Normal file

File diff suppressed because it is too large Load Diff

View File

@ -8,7 +8,7 @@
"minamo": "tsx ./src/minamo.ts",
"merge": "tsx ./src/merge.ts",
"review": "tsx ./src/review.ts",
"gan": "tsx ./src/gan.ts",
"eval": "tsx ./src/eval.ts",
"test:topo": "tsx ./src/topology/test.ts",
"test:vision": "tsx ./src/vision/test.ts"
},

31
data/src/eval.ts Normal file
View File

@ -0,0 +1,31 @@
import { readFile, writeFile } from 'fs-extra';
import { GinkaDataset } from './types';
import { chooseFrom } from './utils';
const [outputTrain, outputEval, input, ratioStr] = process.argv.slice(2);
const ratio = parseFloat(ratioStr);
(async () => {
const data = await readFile(input, 'utf-8');
const dataJSON = JSON.parse(data) as GinkaDataset;
const keys = Object.keys(dataJSON.data);
const length = keys.length;
const toEval = chooseFrom(keys, Math.floor(length * ratio));
const toTrain = [...new Set(keys).difference(new Set(toEval))];
const trainData: GinkaDataset = {
datasetId: Math.floor(Math.random() * 1e12),
data: {}
};
toTrain.forEach(v => {
trainData.data[v] = dataJSON.data[v];
});
const evalData: GinkaDataset = {
datasetId: Math.floor(Math.random() * 1e12),
data: {}
};
toEval.forEach(v => {
evalData.data[v] = dataJSON.data[v];
});
await writeFile(outputTrain, JSON.stringify(trainData), 'utf-8');
await writeFile(outputEval, JSON.stringify(evalData), 'utf-8');
})();

View File

@ -22,6 +22,7 @@ const numMap: Record<number, number> = {
};
export interface Enemy {
num: number;
hp: number;
atk: number;
def: number;
@ -44,41 +45,7 @@ function convert(
clipped.push(row);
}
// 2. 转换怪物
const enemySet = new Set<Enemy>();
for (let nx = 0; nx < w; nx++) {
for (let ny = 0; ny < h; ny++) {
const tile = clipped[ny][nx];
if (tile === 201 || tile === 202 || tile === 203) continue;
const enemy = enemyMap[tile];
if (!enemy) continue;
enemySet.add(enemy);
}
}
const attrs = [...enemySet].map(v => (v.atk + v.def) * v.hp);
const maxAttr = Math.max(...attrs);
const minAttr = Math.min(...attrs);
const delta = maxAttr - minAttr;
for (let ny = 0; ny < w; ny++) {
for (let nx = 0; nx < h; nx++) {
const tile = clipped[ny][nx];
if (tile === 201 || tile === 202 || tile === 203) continue;
const enemy = enemyMap[tile];
if (!enemy) continue;
// 替换为弱怪/中怪/强怪
const attr = (enemy.atk + enemy.def) * enemy.hp;
const ad = attr - minAttr;
if (ad < delta / 3) {
clipped[ny][nx] = 7;
} else if (ad < (delta * 2) / 3) {
clipped[ny][nx] = 8;
} else {
clipped[ny][nx] = 9;
}
}
}
// 3. 转换一般图块
// 2. 转换一般图块
const mapping: Record<number, number> = {};
config.mapping.wall.forEach(v => (mapping[v] = 1));
config.mapping.key.forEach(v => (mapping[v] = 2));
@ -88,12 +55,69 @@ function convert(
config.mapping.door.forEach(v => (mapping[v] = 6));
config.mapping.item.forEach(v => (mapping[v] = 12));
config.mapping.greenGem.forEach(v => (mapping[v] = 13));
const yellowGem = new Set(config.mapping.yellowGem);
for (let nx = 0; nx < w; nx++) {
for (let ny = 0; ny < h; ny++) {
const tile = clipped[ny][nx];
const enemy = enemyMap[tile];
if (yellowGem.has(tile)) {
const rand = Math.random();
if (rand < 2 / 5) {
clipped[ny][nx] = 3;
} else if (rand < 4 / 5) {
clipped[ny][nx] = 4;
} else {
clipped[ny][nx] = 13;
}
continue;
}
if (mapping[tile]) clipped[ny][nx] = mapping[tile];
else if (numMap[tile]) clipped[ny][nx] = numMap[tile];
else clipped[ny][nx] = 0;
else if (!enemy) clipped[ny][nx] = 0;
}
}
// 3. 转换怪物
const enemySet = new Set<Enemy>();
for (let nx = 0; nx < w; nx++) {
for (let ny = 0; ny < h; ny++) {
const tile = clipped[ny][nx];
const enemy = enemyMap[tile];
if (!enemy) continue;
enemySet.add({ ...enemy, num: tile });
}
}
const enemyArr = [...enemySet];
enemyArr.sort((a, b) => a.num - b.num);
if (
enemyArr.length === 3 &&
enemyArr[0].num === 201 &&
enemyArr[1].num === 202 &&
enemyArr[2].num === 203
) {
// pass
} else {
const attrs = [...enemySet].map(v => (v.atk + v.def) * v.hp);
const maxAttr = Math.max(...attrs);
const minAttr = Math.min(...attrs);
const delta = maxAttr - minAttr;
for (let ny = 0; ny < w; ny++) {
for (let nx = 0; nx < h; nx++) {
const tile = clipped[ny][nx];
if (tile < 32) continue;
const enemy = enemyMap[tile];
if (!enemy) continue;
// 替换为弱怪/中怪/强怪
const attr = (enemy.atk + enemy.def) * enemy.hp;
const ad = attr - minAttr;
if (ad < delta / 3) {
clipped[ny][nx] = 7;
} else if (ad < (delta * 2) / 3) {
clipped[ny][nx] = 8;
} else {
clipped[ny][nx] = 9;
}
}
}
}

View File

@ -18,6 +18,7 @@ export interface GinkaConfig extends BaseConfig {
redGem: number[];
blueGem: number[];
greenGem: number[];
yellowGem: number[];
item: number[];
potion: number[];
key: number[];

View File

@ -109,22 +109,30 @@ export async function getAllFloors(...info: TowerInfo[]) {
join(tower.path, 'floors', `${id}.js`),
'utf-8'
);
const data = JSON.parse(
floorFile
.replaceAll("'", '"')
.slice(floorFile.indexOf('=') + 1)
);
const map = data.map as number[][];
// 裁剪地图
const { clip } = tower.config;
const area = clip.special[id] ?? clip.defaults;
try {
const data = JSON.parse(
floorFile
// .replaceAll("'", '"')
.slice(floorFile.indexOf('=') + 1)
);
return convertFloor(
map,
area,
tower.config as GinkaConfig,
enemyNumMap
);
const map = data.map as number[][];
// 裁剪地图
const { clip } = tower.config;
const area = clip.special[id] ?? clip.defaults;
return convertFloor(
map,
area,
tower.config as GinkaConfig,
enemyNumMap
);
} catch (e) {
console.log(
`Error when processing '${tower.name}' '${id}'`
);
throw e;
}
})
);
})

View File

@ -32,7 +32,7 @@ def get_not_allowed(classes: list[int], include_illegal=False):
return res
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11]):
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13]):
"""
强制地图最外圈像素必须为指定类别墙或箭头
@ -418,7 +418,7 @@ class WGANGinkaLoss:
minamo_loss = -torch.mean(fake_scores)
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
constraint_loss = inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)
@ -445,7 +445,7 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
constraint_loss = inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)
@ -469,7 +469,7 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
constraint_loss = inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)