diff --git a/data/src/auto.ts b/data/src/auto.ts index c581a26..84976e1 100644 --- a/data/src/auto.ts +++ b/data/src/auto.ts @@ -2,6 +2,7 @@ import { writeFile } from 'fs/promises'; import { autoLabelTowers } from './auto/auto'; import { IAutoLabelConfig, TowerColor } from './auto/types'; import { GinkaDataset, GinkaTrainData } from './types'; +import { normalizeHeatmap } from './auto/heatmap'; const [, , output, towerInfo, ...folders] = process.argv; @@ -328,11 +329,15 @@ const labelConfig: IAutoLabelConfig = { const data: GinkaTrainData = { map: floor.data.map, size: [width, height], - tag: Array(64).fill(0), + heatmap: [ + normalizeHeatmap(info.wallHeatmap), + normalizeHeatmap(info.enemyHeatmap), + normalizeHeatmap(info.resourceHeatmap), + normalizeHeatmap(info.entryHeatmap) + ], val: [ info.globalDensity, info.wallDensity, - 0, info.doorDensity, info.enemyDensity, info.resourceDensity, @@ -340,9 +345,10 @@ const labelConfig: IAutoLabelConfig = { info.potionDensity, info.keyDensity, info.itemDensity, - info.entryCount, - info.specialDoorCount, - info.fishCount, + info.entryCount / width / height, + 0, + 0, + 0, 0, 0, 0 diff --git a/data/src/auto/heatmap.ts b/data/src/auto/heatmap.ts new file mode 100644 index 0000000..2a27844 --- /dev/null +++ b/data/src/auto/heatmap.ts @@ -0,0 +1,84 @@ +/** + * 将地图转换为热力图 + * @param map 地图矩阵 + * @param tokens 计入热力图的图块 + */ +export function generateHeatmap( + map: number[][], + tokens: Set, + kernel: number = 5 +): number[][] { + if (kernel % 2 !== 1) { + throw new Error(`Kernal size must be odd.`); + } + const width = map[0].length; + const height = map.length; + const result: number[][] = Array.from({ length: height }, _ => + Array.from({ length: width }, _ => 0) + ); + const radius = Math.floor(kernel / 2); + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const left = Math.max(0, x - radius); + const right = Math.min(width, x + radius); + const top = Math.max(0, y - radius); + const bottom = Math.min(height, y + radius); + const size = (right - left) * (bottom - top); + let num = 0; + for (let ky = top; ky < bottom; ky++) { + for (let kx = left; kx < right; kx++) { + if (tokens.has(map[ky][kx])) { + num++; + } + } + } + result[y][x] = num / size; + } + } + return result; +} + +/** + * 对热力图实施高斯模糊 + * @param map 热力图 + * @param sigma 标准差 + */ +export function gaussainHeatmap(map: number[][], sigma: number = 1) { + const radius = sigma * 3; + const width = map[0].length; + const height = map.length; + const result: number[][] = Array.from({ length: height }, _ => + Array.from({ length: width }, _ => 0) + ); + const s = sigma ** 2; + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const left = Math.max(0, x - radius); + const right = Math.min(width - 1, x + radius); + const top = Math.max(0, y - radius); + const bottom = Math.min(height - 1, y + radius); + let res = 0; + for (let ky = top; ky < bottom; ky++) { + for (let kx = left; kx < right; kx++) { + const dis = (ky - y) ** 2 + (kx - x) ** 2; + const g = Math.E ** (-dis / (2 * s)) / (2 * Math.PI * s); + res += map[ky][kx] * g; + } + } + result[y][x] = res; + } + } + return result; +} + +/** + * 归一化热力图 + * @param map 热力图 + */ +export function normalizeHeatmap(map: number[][]) { + const max = Math.max(...map.flat()); + const min = Math.min(...map.flat()); + if (max === min) return map; + const d = max - min; + return map.map(line => line.map(v => (v - min) / d)); +} diff --git a/data/src/auto/info.ts b/data/src/auto/info.ts index c5c22bb..5aaa4b5 100644 --- a/data/src/auto/info.ts +++ b/data/src/auto/info.ts @@ -16,6 +16,7 @@ import { wallTiles } from '../shared'; import { NodeType } from '../topology/interface'; +import { gaussainHeatmap, generateHeatmap } from './heatmap'; interface IRawTowerInfo { /** 作者 id */ @@ -227,7 +228,11 @@ export function parseFloorInfo(tower: ITowerInfo, map: number[][]): IFloorInfo { specialDoorCount: count(flattened, specialDoorTiles), fishCount, hasUselessBranch, - wallDensityStd: computeWallDensityStd(map, wallTiles, 5) + wallDensityStd: computeWallDensityStd(map, wallTiles, 5), + wallHeatmap: gaussainHeatmap(generateHeatmap(map, wallTiles)), + enemyHeatmap: gaussainHeatmap(generateHeatmap(map, enemyTiles)), + resourceHeatmap: gaussainHeatmap(generateHeatmap(map, resourceTiles)), + entryHeatmap: gaussainHeatmap(generateHeatmap(map, entryTiles)) }; return floorInfo; diff --git a/data/src/auto/types.ts b/data/src/auto/types.ts index 1ab92ca..1b59a70 100644 --- a/data/src/auto/types.ts +++ b/data/src/auto/types.ts @@ -93,6 +93,15 @@ export interface IFloorInfo { readonly hasUselessBranch: boolean; /** 墙壁密度标准差 */ readonly wallDensityStd: number; + + /** 怪物热力图 */ + readonly enemyHeatmap: number[][]; + /** 资源热力图 */ + readonly resourceHeatmap: number[][]; + /** 入口热力图 */ + readonly entryHeatmap: number[][]; + /** 墙壁热力图 */ + readonly wallHeatmap: number[][]; } export interface IMapBlockConfig { diff --git a/data/src/types.ts b/data/src/types.ts index 5fa3846..9d11c97 100644 --- a/data/src/types.ts +++ b/data/src/types.ts @@ -43,10 +43,11 @@ export interface GinkaConfig extends BaseConfig { } export interface GinkaTrainData { - tag: number[]; + tag?: number[]; val: number[]; map: number[][]; size: [number, number]; + heatmap?: number[][][]; } export interface GinkaDataset { diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 5b8d1ea..fc5f1ba 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -31,17 +31,14 @@ from .transformer.mask import MapMask # 标量值定义: # 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块 # 1. 墙体密度,墙壁/地图面积 -# 2. 装饰密度,装饰数量/地图面积 -# 3. 门密度,门数量/地图面积 -# 4. 怪物密度,怪物数量/地图面积 -# 5. 资源密度,资源数量/地图面积 -# 6. 宝石密度,宝石数量/地图面积 -# 7. 血瓶密度,血瓶数量/地图面积 -# 8. 钥匙密度,钥匙数量/地图面积 -# 9. 道具密度,道具数量/地图面积 -# 10. 入口数量 -# 11. 机关门数量 -# 12. 咸鱼门数量(多层咸鱼门只算一个) +# 2. 门密度,门数量/地图面积 +# 3. 怪物密度,怪物数量/地图面积 +# 4. 资源密度,资源数量/地图面积 +# 5. 宝石密度,宝石数量/地图面积 +# 6. 血瓶密度,血瓶数量/地图面积 +# 7. 钥匙密度,钥匙数量/地图面积 +# 8. 道具密度,道具数量/地图面积 +# 9. 入口数量 # 图块定义: # 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶