mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 02:11:13 +08:00
feat: 新的数据处理方式
This commit is contained in:
parent
f8d6160e5f
commit
afde7be592
@ -27,11 +27,12 @@ GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对
|
|||||||
"redGem": [27],
|
"redGem": [27],
|
||||||
"blueGem": [28],
|
"blueGem": [28],
|
||||||
"greenGem": [29],
|
"greenGem": [29],
|
||||||
"item": [47, 49, 50, 53],
|
"yellowGem": [30],
|
||||||
|
"item": [47, 49, 50, 51, 52, 53],
|
||||||
"potion": [31, 32, 33, 34],
|
"potion": [31, 32, 33, 34],
|
||||||
"key": [21, 22, 23],
|
"key": [21, 22, 23],
|
||||||
"door": [81, 82, 83, 85],
|
"door": [81, 82, 83, 85],
|
||||||
"wall": [1]
|
"wall": [1, 17]
|
||||||
},
|
},
|
||||||
"data": {}
|
"data": {}
|
||||||
}
|
}
|
||||||
|
|||||||
8707
data/ginka-train.json
Normal file
8707
data/ginka-train.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -8,7 +8,7 @@
|
|||||||
"minamo": "tsx ./src/minamo.ts",
|
"minamo": "tsx ./src/minamo.ts",
|
||||||
"merge": "tsx ./src/merge.ts",
|
"merge": "tsx ./src/merge.ts",
|
||||||
"review": "tsx ./src/review.ts",
|
"review": "tsx ./src/review.ts",
|
||||||
"gan": "tsx ./src/gan.ts",
|
"eval": "tsx ./src/eval.ts",
|
||||||
"test:topo": "tsx ./src/topology/test.ts",
|
"test:topo": "tsx ./src/topology/test.ts",
|
||||||
"test:vision": "tsx ./src/vision/test.ts"
|
"test:vision": "tsx ./src/vision/test.ts"
|
||||||
},
|
},
|
||||||
|
|||||||
31
data/src/eval.ts
Normal file
31
data/src/eval.ts
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import { readFile, writeFile } from 'fs-extra';
|
||||||
|
import { GinkaDataset } from './types';
|
||||||
|
import { chooseFrom } from './utils';
|
||||||
|
|
||||||
|
const [outputTrain, outputEval, input, ratioStr] = process.argv.slice(2);
|
||||||
|
const ratio = parseFloat(ratioStr);
|
||||||
|
|
||||||
|
(async () => {
|
||||||
|
const data = await readFile(input, 'utf-8');
|
||||||
|
const dataJSON = JSON.parse(data) as GinkaDataset;
|
||||||
|
const keys = Object.keys(dataJSON.data);
|
||||||
|
const length = keys.length;
|
||||||
|
const toEval = chooseFrom(keys, Math.floor(length * ratio));
|
||||||
|
const toTrain = [...new Set(keys).difference(new Set(toEval))];
|
||||||
|
const trainData: GinkaDataset = {
|
||||||
|
datasetId: Math.floor(Math.random() * 1e12),
|
||||||
|
data: {}
|
||||||
|
};
|
||||||
|
toTrain.forEach(v => {
|
||||||
|
trainData.data[v] = dataJSON.data[v];
|
||||||
|
});
|
||||||
|
const evalData: GinkaDataset = {
|
||||||
|
datasetId: Math.floor(Math.random() * 1e12),
|
||||||
|
data: {}
|
||||||
|
};
|
||||||
|
toEval.forEach(v => {
|
||||||
|
evalData.data[v] = dataJSON.data[v];
|
||||||
|
});
|
||||||
|
await writeFile(outputTrain, JSON.stringify(trainData), 'utf-8');
|
||||||
|
await writeFile(outputEval, JSON.stringify(evalData), 'utf-8');
|
||||||
|
})();
|
||||||
@ -22,6 +22,7 @@ const numMap: Record<number, number> = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export interface Enemy {
|
export interface Enemy {
|
||||||
|
num: number;
|
||||||
hp: number;
|
hp: number;
|
||||||
atk: number;
|
atk: number;
|
||||||
def: number;
|
def: number;
|
||||||
@ -44,41 +45,7 @@ function convert(
|
|||||||
clipped.push(row);
|
clipped.push(row);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 转换怪物
|
// 2. 转换一般图块
|
||||||
const enemySet = new Set<Enemy>();
|
|
||||||
for (let nx = 0; nx < w; nx++) {
|
|
||||||
for (let ny = 0; ny < h; ny++) {
|
|
||||||
const tile = clipped[ny][nx];
|
|
||||||
if (tile === 201 || tile === 202 || tile === 203) continue;
|
|
||||||
const enemy = enemyMap[tile];
|
|
||||||
if (!enemy) continue;
|
|
||||||
enemySet.add(enemy);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const attrs = [...enemySet].map(v => (v.atk + v.def) * v.hp);
|
|
||||||
const maxAttr = Math.max(...attrs);
|
|
||||||
const minAttr = Math.min(...attrs);
|
|
||||||
const delta = maxAttr - minAttr;
|
|
||||||
for (let ny = 0; ny < w; ny++) {
|
|
||||||
for (let nx = 0; nx < h; nx++) {
|
|
||||||
const tile = clipped[ny][nx];
|
|
||||||
if (tile === 201 || tile === 202 || tile === 203) continue;
|
|
||||||
const enemy = enemyMap[tile];
|
|
||||||
if (!enemy) continue;
|
|
||||||
// 替换为弱怪/中怪/强怪
|
|
||||||
const attr = (enemy.atk + enemy.def) * enemy.hp;
|
|
||||||
const ad = attr - minAttr;
|
|
||||||
if (ad < delta / 3) {
|
|
||||||
clipped[ny][nx] = 7;
|
|
||||||
} else if (ad < (delta * 2) / 3) {
|
|
||||||
clipped[ny][nx] = 8;
|
|
||||||
} else {
|
|
||||||
clipped[ny][nx] = 9;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 转换一般图块
|
|
||||||
const mapping: Record<number, number> = {};
|
const mapping: Record<number, number> = {};
|
||||||
config.mapping.wall.forEach(v => (mapping[v] = 1));
|
config.mapping.wall.forEach(v => (mapping[v] = 1));
|
||||||
config.mapping.key.forEach(v => (mapping[v] = 2));
|
config.mapping.key.forEach(v => (mapping[v] = 2));
|
||||||
@ -88,12 +55,69 @@ function convert(
|
|||||||
config.mapping.door.forEach(v => (mapping[v] = 6));
|
config.mapping.door.forEach(v => (mapping[v] = 6));
|
||||||
config.mapping.item.forEach(v => (mapping[v] = 12));
|
config.mapping.item.forEach(v => (mapping[v] = 12));
|
||||||
config.mapping.greenGem.forEach(v => (mapping[v] = 13));
|
config.mapping.greenGem.forEach(v => (mapping[v] = 13));
|
||||||
|
const yellowGem = new Set(config.mapping.yellowGem);
|
||||||
for (let nx = 0; nx < w; nx++) {
|
for (let nx = 0; nx < w; nx++) {
|
||||||
for (let ny = 0; ny < h; ny++) {
|
for (let ny = 0; ny < h; ny++) {
|
||||||
const tile = clipped[ny][nx];
|
const tile = clipped[ny][nx];
|
||||||
|
const enemy = enemyMap[tile];
|
||||||
|
if (yellowGem.has(tile)) {
|
||||||
|
const rand = Math.random();
|
||||||
|
if (rand < 2 / 5) {
|
||||||
|
clipped[ny][nx] = 3;
|
||||||
|
} else if (rand < 4 / 5) {
|
||||||
|
clipped[ny][nx] = 4;
|
||||||
|
} else {
|
||||||
|
clipped[ny][nx] = 13;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (mapping[tile]) clipped[ny][nx] = mapping[tile];
|
if (mapping[tile]) clipped[ny][nx] = mapping[tile];
|
||||||
else if (numMap[tile]) clipped[ny][nx] = numMap[tile];
|
else if (numMap[tile]) clipped[ny][nx] = numMap[tile];
|
||||||
else clipped[ny][nx] = 0;
|
else if (!enemy) clipped[ny][nx] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 转换怪物
|
||||||
|
const enemySet = new Set<Enemy>();
|
||||||
|
for (let nx = 0; nx < w; nx++) {
|
||||||
|
for (let ny = 0; ny < h; ny++) {
|
||||||
|
const tile = clipped[ny][nx];
|
||||||
|
const enemy = enemyMap[tile];
|
||||||
|
if (!enemy) continue;
|
||||||
|
enemySet.add({ ...enemy, num: tile });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const enemyArr = [...enemySet];
|
||||||
|
enemyArr.sort((a, b) => a.num - b.num);
|
||||||
|
if (
|
||||||
|
enemyArr.length === 3 &&
|
||||||
|
enemyArr[0].num === 201 &&
|
||||||
|
enemyArr[1].num === 202 &&
|
||||||
|
enemyArr[2].num === 203
|
||||||
|
) {
|
||||||
|
// pass
|
||||||
|
} else {
|
||||||
|
const attrs = [...enemySet].map(v => (v.atk + v.def) * v.hp);
|
||||||
|
const maxAttr = Math.max(...attrs);
|
||||||
|
const minAttr = Math.min(...attrs);
|
||||||
|
const delta = maxAttr - minAttr;
|
||||||
|
for (let ny = 0; ny < w; ny++) {
|
||||||
|
for (let nx = 0; nx < h; nx++) {
|
||||||
|
const tile = clipped[ny][nx];
|
||||||
|
if (tile < 32) continue;
|
||||||
|
const enemy = enemyMap[tile];
|
||||||
|
if (!enemy) continue;
|
||||||
|
// 替换为弱怪/中怪/强怪
|
||||||
|
const attr = (enemy.atk + enemy.def) * enemy.hp;
|
||||||
|
const ad = attr - minAttr;
|
||||||
|
if (ad < delta / 3) {
|
||||||
|
clipped[ny][nx] = 7;
|
||||||
|
} else if (ad < (delta * 2) / 3) {
|
||||||
|
clipped[ny][nx] = 8;
|
||||||
|
} else {
|
||||||
|
clipped[ny][nx] = 9;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -18,6 +18,7 @@ export interface GinkaConfig extends BaseConfig {
|
|||||||
redGem: number[];
|
redGem: number[];
|
||||||
blueGem: number[];
|
blueGem: number[];
|
||||||
greenGem: number[];
|
greenGem: number[];
|
||||||
|
yellowGem: number[];
|
||||||
item: number[];
|
item: number[];
|
||||||
potion: number[];
|
potion: number[];
|
||||||
key: number[];
|
key: number[];
|
||||||
|
|||||||
@ -109,22 +109,30 @@ export async function getAllFloors(...info: TowerInfo[]) {
|
|||||||
join(tower.path, 'floors', `${id}.js`),
|
join(tower.path, 'floors', `${id}.js`),
|
||||||
'utf-8'
|
'utf-8'
|
||||||
);
|
);
|
||||||
const data = JSON.parse(
|
try {
|
||||||
floorFile
|
const data = JSON.parse(
|
||||||
.replaceAll("'", '"')
|
floorFile
|
||||||
.slice(floorFile.indexOf('=') + 1)
|
// .replaceAll("'", '"')
|
||||||
);
|
.slice(floorFile.indexOf('=') + 1)
|
||||||
const map = data.map as number[][];
|
);
|
||||||
// 裁剪地图
|
|
||||||
const { clip } = tower.config;
|
|
||||||
const area = clip.special[id] ?? clip.defaults;
|
|
||||||
|
|
||||||
return convertFloor(
|
const map = data.map as number[][];
|
||||||
map,
|
// 裁剪地图
|
||||||
area,
|
const { clip } = tower.config;
|
||||||
tower.config as GinkaConfig,
|
const area = clip.special[id] ?? clip.defaults;
|
||||||
enemyNumMap
|
|
||||||
);
|
return convertFloor(
|
||||||
|
map,
|
||||||
|
area,
|
||||||
|
tower.config as GinkaConfig,
|
||||||
|
enemyNumMap
|
||||||
|
);
|
||||||
|
} catch (e) {
|
||||||
|
console.log(
|
||||||
|
`Error when processing '${tower.name}' '${id}'`
|
||||||
|
);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
})
|
})
|
||||||
|
|||||||
@ -32,7 +32,7 @@ def get_not_allowed(classes: list[int], include_illegal=False):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11]):
|
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13]):
|
||||||
"""
|
"""
|
||||||
强制地图最外圈像素必须为指定类别(墙或箭头)
|
强制地图最外圈像素必须为指定类别(墙或箭头)
|
||||||
|
|
||||||
@ -418,7 +418,7 @@ class WGANGinkaLoss:
|
|||||||
minamo_loss = -torch.mean(fake_scores)
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
|
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
|
||||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
constraint_loss = inner_constraint_loss(probs_fake)
|
||||||
|
|
||||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||||
|
|
||||||
@ -445,7 +445,7 @@ class WGANGinkaLoss:
|
|||||||
|
|
||||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||||
minamo_loss = -torch.mean(fake_scores)
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
constraint_loss = inner_constraint_loss(probs_fake)
|
||||||
|
|
||||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||||
|
|
||||||
@ -469,7 +469,7 @@ class WGANGinkaLoss:
|
|||||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||||
minamo_loss = -torch.mean(fake_scores)
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
constraint_loss = inner_constraint_loss(probs_fake)
|
||||||
|
|
||||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user