diff --git a/.gitignore b/.gitignore index 12c7e29..02bff83 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ node_modules ginka-dataset.json ginka-eval.json minamo-dataset.json -minamo-eval.json \ No newline at end of file +minamo-eval.json +datasets \ No newline at end of file diff --git a/data/src/minamo.ts b/data/src/minamo.ts index 655e115..13a2854 100644 --- a/data/src/minamo.ts +++ b/data/src/minamo.ts @@ -1,11 +1,12 @@ import { writeFile } from 'fs-extra'; -import { FloorData, getAllFloors, parseTowerInfo } from './utils'; +import { FloorData, readOne, getAllFloors, parseTowerInfo } from './utils'; import { compareMap } from './topology/compare'; import { mirrorMapX, mirrorMapY, rotateMap } from './topology/transform'; import { directions, tileType } from './topology/graph'; import { calculateVisualSimilarity } from './vision/similarity'; import { BaseConfig } from './types'; import { Presets, SingleBar } from 'cli-progress'; +import { log } from 'console'; interface MinamoConfig extends BaseConfig {} @@ -23,6 +24,9 @@ interface MinamoDataset { } const [output, ...list] = process.argv.slice(2); +// 判断 assigned 模式,此模式下只会对前两个塔处理,会在这两个塔之间对比,而单个塔的地图不会对比 +const assigned = list.at(-1) === 'assigned'; +if (assigned) list.pop(); function chooseFrom(arr: T[], n: number): T[] { const copy = arr.slice(); @@ -33,6 +37,15 @@ function chooseFrom(arr: T[], n: number): T[] { return copy.slice(0, n); } +function chooseN(maxCount: number, n: number) { + return chooseFrom( + Array(maxCount) + .fill(0) + .map((_, i) => i), + n + ); +} + function choosePair(n: number, max: number = 1000) { const totalCount = Math.round((n * (n - 1)) / 2); const count = Math.min(totalCount, max); @@ -204,6 +217,59 @@ function generateSimilarData(id: string, map: number[][]) { return res; } +function generatePair( + data: Record, + id1: string, + id2: string, + map1: number[][], + map2: number[][], + size: [number, number] +) { + const topoSimilarity = compareMap(id1, id2, map1, map2); + const visionSimilarity = calculateVisualSimilarity(map1, map2); + const train: MinamoTrainData = { + map1, + map2, + topoSimilarity, + visionSimilarity, + size: size + }; + data[`${id1}:${id2}`] = train; + // 自身与自身对比的训练集,保证模型对相同地图输出 1 + const self1 = `${id1}:${id1}`; + const self2 = `${id2}:${id2}`; + const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 3)); + if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) { + const selfTrain1: MinamoTrainData = { + map1: map1, + map2: map1, + topoSimilarity: 1, + visionSimilarity: 1, + size: size + }; + data[`${id1}:${id1}`] = selfTrain1; + } + if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) { + const selfTrain2: MinamoTrainData = { + map1: map2, + map2: map2, + topoSimilarity: 1, + visionSimilarity: 1, + size: size + }; + data[`${id2}:${id2}`] = selfTrain2; + } + // 翻转、旋转训练集 + Object.assign( + data, + Object.fromEntries( + generateTransformData(id1, id2, map1, map2, topoSimilarity) + ) + ); + // 地图微调训练集 + Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1))); +} + function generateDataset( floors: Map, pairs: number[], @@ -226,53 +292,7 @@ function generateDataset( const [w1, h1] = [map1[0].length, map1.length]; const [w2, h2] = [map2[0].length, map2.length]; if (w1 !== w2 || h1 !== h2) return; - const topoSimilarity = compareMap(id1, id2, map1, map2); - const visionSimilarity = calculateVisualSimilarity(map1, map2); - const train: MinamoTrainData = { - map1, - map2, - topoSimilarity, - visionSimilarity, - size: [w1, h1] - }; - data[`${id1}:${id2}`] = train; - // 自身与自身对比的训练集,保证模型对相同地图输出 1 - const self1 = `${id1}:${id1}`; - const self2 = `${id2}:${id2}`; - const selfTrain = chooseFrom( - [self1, self2], - Math.floor(Math.random() * 3) - ); - if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) { - const selfTrain1: MinamoTrainData = { - map1: map1, - map2: map1, - topoSimilarity: 1, - visionSimilarity: 1, - size: [w1, h1] - }; - data[`${id1}:${id1}`] = selfTrain1; - } - if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) { - const selfTrain2: MinamoTrainData = { - map1: map2, - map2: map2, - topoSimilarity: 1, - visionSimilarity: 1, - size: [w1, h1] - }; - data[`${id2}:${id2}`] = selfTrain2; - } - // 翻转、旋转训练集 - Object.assign( - data, - Object.fromEntries( - generateTransformData(id1, id2, map1, map2, topoSimilarity) - ) - ); - // 地图微调训练集 - Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1))); - // Object.assign(data, Object.fromEntries(generateSimilarData(id2, map2))); + generatePair(data, id1, id2, map1, map2, [w1, h1]); progress.update(i + 1); }); @@ -301,13 +321,76 @@ function parseAllData(data: Map): MinamoDataset { return dataset; } -(async () => { - const towers = await Promise.all( - list.map(v => parseTowerInfo(v, 'minamo-config.json')) +function generateAssignedData( + data1: Map, + data2: Map +): MinamoDataset { + const length = data1.size + data2.size; + const totalCount = data1.size * data2.size; + const count1 = Math.min(100, data1.size); + const count2 = Math.min(100, data2.size); + const keys1 = [...data1.keys()]; + const keys2 = [...data2.keys()]; + const choose1 = chooseFrom(keys1, count1); + const choose2 = chooseFrom(keys2, count2); + + const trainData: Record = {}; + + console.log( + `✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${ + count1 * count2 + } 个组合` ); - const floors = await getAllFloors(...towers); - const results = parseAllData(floors); - await writeFile(output, JSON.stringify(results, void 0), 'utf-8'); - const size = Object.keys(results.data).length; - console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`); + + const progress = new SingleBar({}, Presets.shades_classic); + progress.start(count1 * count2, 0); + let n = 0; + + for (const key1 of choose1) { + for (const key2 of choose2) { + const { map: map1 } = data1.get(key1)!; + const { map: map2 } = data2.get(key2)!; + if (!map1 || !map2) continue; + const [w1, h1] = [map1[0].length, map1.length]; + const [w2, h2] = [map2[0].length, map2.length]; + if (w1 !== w2 || h1 !== h2) continue; + generatePair(trainData, key1, key2, map1, map2, [w1, h1]); + n++; + progress.update(n); + } + } + + progress.stop(); + + const dataset: MinamoDataset = { + datasetId: Math.floor(Math.random() * 1e12), + data: trainData + }; + + return dataset; +} + +(async () => { + if (!assigned) { + const towers = await Promise.all( + list.map(v => parseTowerInfo(v, 'minamo-config.json')) + ); + const floors = await getAllFloors(...towers); + const results = parseAllData(floors); + await writeFile(output, JSON.stringify(results, void 0), 'utf-8'); + const size = Object.keys(results.data).length; + console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`); + } else { + const [tower1, tower2] = list; + if (!tower1 || !tower2) { + console.log(`⚠️ assigned 模式下必须传入两个塔!`); + return; + } + const data1 = await readOne(tower1); + const data2 = await readOne(tower2); + const results = generateAssignedData(data1, data2); + await writeFile(output, JSON.stringify(results, void 0), 'utf-8'); + const size = Object.keys(results.data).length; + console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`); + } })(); diff --git a/data/src/utils.ts b/data/src/utils.ts index 78e1d44..9d70db9 100644 --- a/data/src/utils.ts +++ b/data/src/utils.ts @@ -144,3 +144,34 @@ export function mergeFloorIds(...info: TowerInfo[]) { }); return ids; } + +export async function readOne(path: string) { + if (path.endsWith('.json')) { + return fromJSON(path); + } else { + return getAllFloors(await parseTowerInfo(path, 'minamo-config.json')); + } +} + +export async function fromJSON(path: string) { + const file = await readFile(path, 'utf-8'); + const data = JSON.parse(file) as Record; + const clip: Record = {}; + const config: BaseConfig = { + clip: { + defaults: [0, 0, 0, 0], + special: clip + } + }; + const name = (Math.random() * 12).toFixed(0); + const floorMap = new Map(); + for (const [key, value] of Object.entries(data)) { + const floorData: FloorData = { + map: value, + id: key, + config + }; + floorMap.set(`${name}:${key}`, floorData); + } + return floorMap; +}