mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 新的数据处理方式
This commit is contained in:
parent
f8d6160e5f
commit
afde7be592
@ -27,11 +27,12 @@ GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对
|
||||
"redGem": [27],
|
||||
"blueGem": [28],
|
||||
"greenGem": [29],
|
||||
"item": [47, 49, 50, 53],
|
||||
"yellowGem": [30],
|
||||
"item": [47, 49, 50, 51, 52, 53],
|
||||
"potion": [31, 32, 33, 34],
|
||||
"key": [21, 22, 23],
|
||||
"door": [81, 82, 83, 85],
|
||||
"wall": [1]
|
||||
"wall": [1, 17]
|
||||
},
|
||||
"data": {}
|
||||
}
|
||||
|
||||
8707
data/ginka-train.json
Normal file
8707
data/ginka-train.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -8,7 +8,7 @@
|
||||
"minamo": "tsx ./src/minamo.ts",
|
||||
"merge": "tsx ./src/merge.ts",
|
||||
"review": "tsx ./src/review.ts",
|
||||
"gan": "tsx ./src/gan.ts",
|
||||
"eval": "tsx ./src/eval.ts",
|
||||
"test:topo": "tsx ./src/topology/test.ts",
|
||||
"test:vision": "tsx ./src/vision/test.ts"
|
||||
},
|
||||
|
||||
31
data/src/eval.ts
Normal file
31
data/src/eval.ts
Normal file
@ -0,0 +1,31 @@
|
||||
import { readFile, writeFile } from 'fs-extra';
|
||||
import { GinkaDataset } from './types';
|
||||
import { chooseFrom } from './utils';
|
||||
|
||||
const [outputTrain, outputEval, input, ratioStr] = process.argv.slice(2);
|
||||
const ratio = parseFloat(ratioStr);
|
||||
|
||||
(async () => {
|
||||
const data = await readFile(input, 'utf-8');
|
||||
const dataJSON = JSON.parse(data) as GinkaDataset;
|
||||
const keys = Object.keys(dataJSON.data);
|
||||
const length = keys.length;
|
||||
const toEval = chooseFrom(keys, Math.floor(length * ratio));
|
||||
const toTrain = [...new Set(keys).difference(new Set(toEval))];
|
||||
const trainData: GinkaDataset = {
|
||||
datasetId: Math.floor(Math.random() * 1e12),
|
||||
data: {}
|
||||
};
|
||||
toTrain.forEach(v => {
|
||||
trainData.data[v] = dataJSON.data[v];
|
||||
});
|
||||
const evalData: GinkaDataset = {
|
||||
datasetId: Math.floor(Math.random() * 1e12),
|
||||
data: {}
|
||||
};
|
||||
toEval.forEach(v => {
|
||||
evalData.data[v] = dataJSON.data[v];
|
||||
});
|
||||
await writeFile(outputTrain, JSON.stringify(trainData), 'utf-8');
|
||||
await writeFile(outputEval, JSON.stringify(evalData), 'utf-8');
|
||||
})();
|
||||
@ -22,6 +22,7 @@ const numMap: Record<number, number> = {
|
||||
};
|
||||
|
||||
export interface Enemy {
|
||||
num: number;
|
||||
hp: number;
|
||||
atk: number;
|
||||
def: number;
|
||||
@ -44,41 +45,7 @@ function convert(
|
||||
clipped.push(row);
|
||||
}
|
||||
|
||||
// 2. 转换怪物
|
||||
const enemySet = new Set<Enemy>();
|
||||
for (let nx = 0; nx < w; nx++) {
|
||||
for (let ny = 0; ny < h; ny++) {
|
||||
const tile = clipped[ny][nx];
|
||||
if (tile === 201 || tile === 202 || tile === 203) continue;
|
||||
const enemy = enemyMap[tile];
|
||||
if (!enemy) continue;
|
||||
enemySet.add(enemy);
|
||||
}
|
||||
}
|
||||
const attrs = [...enemySet].map(v => (v.atk + v.def) * v.hp);
|
||||
const maxAttr = Math.max(...attrs);
|
||||
const minAttr = Math.min(...attrs);
|
||||
const delta = maxAttr - minAttr;
|
||||
for (let ny = 0; ny < w; ny++) {
|
||||
for (let nx = 0; nx < h; nx++) {
|
||||
const tile = clipped[ny][nx];
|
||||
if (tile === 201 || tile === 202 || tile === 203) continue;
|
||||
const enemy = enemyMap[tile];
|
||||
if (!enemy) continue;
|
||||
// 替换为弱怪/中怪/强怪
|
||||
const attr = (enemy.atk + enemy.def) * enemy.hp;
|
||||
const ad = attr - minAttr;
|
||||
if (ad < delta / 3) {
|
||||
clipped[ny][nx] = 7;
|
||||
} else if (ad < (delta * 2) / 3) {
|
||||
clipped[ny][nx] = 8;
|
||||
} else {
|
||||
clipped[ny][nx] = 9;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 转换一般图块
|
||||
// 2. 转换一般图块
|
||||
const mapping: Record<number, number> = {};
|
||||
config.mapping.wall.forEach(v => (mapping[v] = 1));
|
||||
config.mapping.key.forEach(v => (mapping[v] = 2));
|
||||
@ -88,12 +55,69 @@ function convert(
|
||||
config.mapping.door.forEach(v => (mapping[v] = 6));
|
||||
config.mapping.item.forEach(v => (mapping[v] = 12));
|
||||
config.mapping.greenGem.forEach(v => (mapping[v] = 13));
|
||||
const yellowGem = new Set(config.mapping.yellowGem);
|
||||
for (let nx = 0; nx < w; nx++) {
|
||||
for (let ny = 0; ny < h; ny++) {
|
||||
const tile = clipped[ny][nx];
|
||||
const enemy = enemyMap[tile];
|
||||
if (yellowGem.has(tile)) {
|
||||
const rand = Math.random();
|
||||
if (rand < 2 / 5) {
|
||||
clipped[ny][nx] = 3;
|
||||
} else if (rand < 4 / 5) {
|
||||
clipped[ny][nx] = 4;
|
||||
} else {
|
||||
clipped[ny][nx] = 13;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (mapping[tile]) clipped[ny][nx] = mapping[tile];
|
||||
else if (numMap[tile]) clipped[ny][nx] = numMap[tile];
|
||||
else clipped[ny][nx] = 0;
|
||||
else if (!enemy) clipped[ny][nx] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 转换怪物
|
||||
const enemySet = new Set<Enemy>();
|
||||
for (let nx = 0; nx < w; nx++) {
|
||||
for (let ny = 0; ny < h; ny++) {
|
||||
const tile = clipped[ny][nx];
|
||||
const enemy = enemyMap[tile];
|
||||
if (!enemy) continue;
|
||||
enemySet.add({ ...enemy, num: tile });
|
||||
}
|
||||
}
|
||||
const enemyArr = [...enemySet];
|
||||
enemyArr.sort((a, b) => a.num - b.num);
|
||||
if (
|
||||
enemyArr.length === 3 &&
|
||||
enemyArr[0].num === 201 &&
|
||||
enemyArr[1].num === 202 &&
|
||||
enemyArr[2].num === 203
|
||||
) {
|
||||
// pass
|
||||
} else {
|
||||
const attrs = [...enemySet].map(v => (v.atk + v.def) * v.hp);
|
||||
const maxAttr = Math.max(...attrs);
|
||||
const minAttr = Math.min(...attrs);
|
||||
const delta = maxAttr - minAttr;
|
||||
for (let ny = 0; ny < w; ny++) {
|
||||
for (let nx = 0; nx < h; nx++) {
|
||||
const tile = clipped[ny][nx];
|
||||
if (tile < 32) continue;
|
||||
const enemy = enemyMap[tile];
|
||||
if (!enemy) continue;
|
||||
// 替换为弱怪/中怪/强怪
|
||||
const attr = (enemy.atk + enemy.def) * enemy.hp;
|
||||
const ad = attr - minAttr;
|
||||
if (ad < delta / 3) {
|
||||
clipped[ny][nx] = 7;
|
||||
} else if (ad < (delta * 2) / 3) {
|
||||
clipped[ny][nx] = 8;
|
||||
} else {
|
||||
clipped[ny][nx] = 9;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ export interface GinkaConfig extends BaseConfig {
|
||||
redGem: number[];
|
||||
blueGem: number[];
|
||||
greenGem: number[];
|
||||
yellowGem: number[];
|
||||
item: number[];
|
||||
potion: number[];
|
||||
key: number[];
|
||||
|
||||
@ -109,22 +109,30 @@ export async function getAllFloors(...info: TowerInfo[]) {
|
||||
join(tower.path, 'floors', `${id}.js`),
|
||||
'utf-8'
|
||||
);
|
||||
const data = JSON.parse(
|
||||
floorFile
|
||||
.replaceAll("'", '"')
|
||||
.slice(floorFile.indexOf('=') + 1)
|
||||
);
|
||||
const map = data.map as number[][];
|
||||
// 裁剪地图
|
||||
const { clip } = tower.config;
|
||||
const area = clip.special[id] ?? clip.defaults;
|
||||
try {
|
||||
const data = JSON.parse(
|
||||
floorFile
|
||||
// .replaceAll("'", '"')
|
||||
.slice(floorFile.indexOf('=') + 1)
|
||||
);
|
||||
|
||||
return convertFloor(
|
||||
map,
|
||||
area,
|
||||
tower.config as GinkaConfig,
|
||||
enemyNumMap
|
||||
);
|
||||
const map = data.map as number[][];
|
||||
// 裁剪地图
|
||||
const { clip } = tower.config;
|
||||
const area = clip.special[id] ?? clip.defaults;
|
||||
|
||||
return convertFloor(
|
||||
map,
|
||||
area,
|
||||
tower.config as GinkaConfig,
|
||||
enemyNumMap
|
||||
);
|
||||
} catch (e) {
|
||||
console.log(
|
||||
`Error when processing '${tower.name}' '${id}'`
|
||||
);
|
||||
throw e;
|
||||
}
|
||||
})
|
||||
);
|
||||
})
|
||||
|
||||
@ -32,7 +32,7 @@ def get_not_allowed(classes: list[int], include_illegal=False):
|
||||
|
||||
return res
|
||||
|
||||
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11]):
|
||||
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13]):
|
||||
"""
|
||||
强制地图最外圈像素必须为指定类别(墙或箭头)
|
||||
|
||||
@ -418,7 +418,7 @@ class WGANGinkaLoss:
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
@ -445,7 +445,7 @@ class WGANGinkaLoss:
|
||||
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
@ -469,7 +469,7 @@ class WGANGinkaLoss:
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user