diff --git a/data/package.json b/data/package.json index 143d379..1e74ace 100644 --- a/data/package.json +++ b/data/package.json @@ -8,6 +8,7 @@ "minamo": "tsx ./src/minamo.ts", "merge": "tsx ./src/merge.ts", "review": "tsx ./src/review.ts", + "gan": "tsx ./src/gan.ts", "test:topo": "tsx ./src/topology/test.ts", "test:vision": "tsx ./src/vision/test.ts" }, diff --git a/data/src/gan.ts b/data/src/gan.ts new file mode 100644 index 0000000..05ab1ca --- /dev/null +++ b/data/src/gan.ts @@ -0,0 +1,131 @@ +import { createConnection, Socket } from 'net'; +import { chooseFrom, FloorData, readOne } from './utils'; +import { MinamoTrainData } from './types'; +import { generateTrainData } from './process/minamo'; + +const SOCKET_FILE = '../tmp/ginka_uds'; +const [refer] = process.argv.slice(2); + +let id = 0; + +function readMap(count: number, buffer: Buffer, h: number, w: number) { + const area = w * h; + + const maps: number[][][] = Array.from({ + length: count + }).map(() => { + return Array.from({ length: h }).map(() => { + return Array.from({ length: w }).fill(0); + }); + }); + + buffer.subarray(4).forEach((v, i) => { + const n = Math.floor(i / area); + const y = Math.floor((i % area) / w); + const x = i % w; + maps[n][y][x] = v; + }); + + return maps; +} + +function generateGANData( + keys: string[], + refer: Map, + map: number[][] +) { + const id2 = `$${id++}`; + const toTrain = chooseFrom(keys, 4); + const data = toTrain.map(v => { + const floor = refer.get(v); + if (!floor) return []; + const size1: [number, number] = [floor.map[0].length, floor.map.length]; + const size2: [number, number] = [map[0].length, map.length]; + if (size1[0] !== size2[0] || size1[1] !== size2[1]) return []; + + return generateTrainData(v, id2, floor.map, map, size1); + }); + return data.flat(); +} + +(async () => { + const referTower = await readOne(refer); + const keys = [...referTower.keys()]; + + const client = createConnection(SOCKET_FILE, () => { + console.log(`UDS IPC connected successfully.`); + // 发送四字节数据表示连接成功 + client.write(new Uint8Array([0x00, 0x00, 0x00, 0x00])); + }); + + client.on('data', buffer => { + // 暂时不考虑流式传输,如果后续数据量非常大,再考虑优化 + // 数据通讯 node 输入协议,单位字节: + // 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type. + const count = buffer.readInt16BE(); + if (buffer.length - 4 !== count * 32 * 32) { + client.write(`ERROR: byte length not match.`); + return []; + } + const h = buffer.readInt8(2); + const w = buffer.readInt8(3); + const map = readMap(count, buffer, h, w); + const simData = map.map(v => generateGANData(keys, referTower, v)); + const rc = 0; + const compareData = simData.flat(); + const reviewData: MinamoTrainData[] = []; + + // 数据通讯 node 输出协议,单位字节: + // 2 - Tensor count; 2 - Review count. Review is right behind train data; + // 1*tc - Compare count for every map tensor delivered. + // 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo; + // N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor. + const toSend = Buffer.alloc( + 2 + // Tensor count + 2 + // Review count + count + // Compare count + 2 * (count + rc) + // Similarity data + compareData.length * 1 * h * w + // Compare map + rc * 2 * h * w, // Review map + 0 + ); + let offset = 0; + toSend.writeInt16BE(count); // Tensor count + toSend.writeInt16BE(0, 2); // Review count + offset += 2 + 2; + // Compare count + toSend.set( + simData.map(v => v.length), + offset + ); + offset += count; + // Similarity data + compareData.forEach(v => { + toSend.writeFloatBE(v.visionSimilarity, offset); + offset += 4; + toSend.writeFloatBE(v.topoSimilarity, offset); + offset += 4; + }); + // Compare map + toSend.set( + compareData.map(v => v.map1).flat(2), + offset // Set from Compare map + ); + offset += compareData.length * 1 * h * w; + // Review map + toSend.set( + reviewData.map(v => [v.map1, v.map2]).flat(3), + offset // Set from last chunk + ); + + client.write(toSend); + }); + + client.on('end', () => { + console.log(`Connection lose.`); + }); + + client.on('error', () => { + client.end(); + }); +})(); diff --git a/data/src/ginka.ts b/data/src/ginka.ts index 83897bc..39723c2 100644 --- a/data/src/ginka.ts +++ b/data/src/ginka.ts @@ -1,61 +1,15 @@ import { writeFile } from 'fs-extra'; -import { FloorData, getAllFloors, parseTowerInfo } from './utils'; -import { Presets, SingleBar } from 'cli-progress'; - -interface GinkaConfig { - clip: { - defaults: [number, number, number, number]; - special: Record; - }; - data: Record; -} - -interface GinkaTrainData { - text: string[]; - map: number[][]; - size: [number, number]; -} - -interface GinkaDataset { - datasetId: number; - data: Record; -} +import { getAllFloors, parseTowerInfo } from './utils'; +import { parseGinka } from './process/ginka'; const [output, ...list] = process.argv.slice(2); -function parseAllData(data: Map) { - const resolved: Record = {}; - - const progress = new SingleBar({}, Presets.shades_classic); - progress.start(data.size, 0); - let i = 0; - - data.forEach((floor, key) => { - const config = floor.config as GinkaConfig; - const text = config.data[floor.id] ?? []; - resolved[key] = { - map: floor.map, - size: [floor.map[0].length, floor.map.length], - text: text - }; - i++; - progress.update(i); - }); - - const dataset: GinkaDataset = { - datasetId: Math.floor(Math.random() * 1e12), - data: resolved - }; - - return dataset; -} - (async () => { const towers = await Promise.all( list.map(v => parseTowerInfo(v, 'ginka-config.json')) ); const floors = await getAllFloors(...towers); - const results = parseAllData(floors); + const results = parseGinka(floors); await writeFile(output, JSON.stringify(results, void 0), 'utf-8'); const size = Object.keys(results.data).length; console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个地图`); diff --git a/data/src/minamo.ts b/data/src/minamo.ts index f15218c..b0cbe56 100644 --- a/data/src/minamo.ts +++ b/data/src/minamo.ts @@ -1,32 +1,6 @@ import { writeFile } from 'fs-extra'; -import { - FloorData, - readOne, - getAllFloors, - parseTowerInfo, - chooseFrom -} from './utils'; -import { compareMap } from './topology/compare'; -import { mirrorMapX, mirrorMapY, rotateMap } from './topology/transform'; -import { directions, tileType } from './topology/graph'; -import { calculateVisualSimilarity } from './vision/similarity'; -import { BaseConfig } from './types'; -import { Presets, SingleBar } from 'cli-progress'; - -interface MinamoConfig extends BaseConfig {} - -interface MinamoTrainData { - map1: number[][]; - map2: number[][]; - topoSimilarity: number; - visionSimilarity: number; - size: [number, number]; -} - -interface MinamoDataset { - datasetId: number; - data: Record; -} +import { readOne, getAllFloors, parseTowerInfo } from './utils'; +import { generateAssignedData, parseMinamo } from './process/minamo'; const [output, ...list] = process.argv.slice(2); // 判断 assigned 模式,此模式下只会对前两个塔处理,会在这两个塔之间对比,而单个塔的地图不会对比 @@ -40,348 +14,13 @@ function parseAssigned(arg: string): [number, number] { return [parseInt(a) || 100, parseInt(b) || 100]; } -function chooseN(maxCount: number, n: number) { - return chooseFrom( - Array(maxCount) - .fill(0) - .map((_, i) => i), - n - ); -} - -function choosePair(n: number, max: number = 1000) { - const totalCount = Math.round((n * (n - 1)) / 2); - const count = Math.min(totalCount, max); - const pairs: number[] = []; - for (let i = 0; i < n; i++) { - for (let j = i + 1; j < n; j++) { - pairs.push(i * n + j); - } - } - // 直接打乱后取前 count 个 - for (let i = pairs.length - 1; i > 0; i--) { - let randIndex = Math.floor(Math.random() * (i + 1)); - [pairs[i], pairs[randIndex]] = [pairs[randIndex], pairs[i]]; - } - - return pairs.slice(0, count); -} - -function transform(map: number[][], rot: number, flip: number) { - let res = map; - for (let i = 0; i < rot; i++) { - res = rotateMap(res); - } - if (flip & 0b01) { - res = mirrorMapX(res); - } - if (flip & 0b10) { - res = mirrorMapY(res); - } - return res; -} - -function generateTransformData( - id1: string, - id2: string, - map1: number[][], - map2: number[][], - simi: number -) { - const types: [rot: number, flip: number][] = []; - for (const rot of [0, 1, 2, 3]) { - for (const flip of [0b00, 0b01, 0b10, 0b11]) { - if (rot === 0 && flip === 0) continue; - types.push([rot, flip]); - } - } - // 随机抽取最多一个 - const trans = chooseFrom(types, Math.floor(Math.random() * 1)); - return trans - .map(([rot, flip]) => { - const com1 = `${id1}.${rot}.${flip}:${id1}`; - const com2 = `${id1}.${rot}.${flip}:${id2}`; - const com3 = `${id2}.${rot}.${flip}:${id1}`; - const com4 = `${id2}.${rot}.${flip}:${id2}`; - const choose = chooseFrom( - [com1, com2, com3, com4], - Math.floor(Math.random() * 2) - ); - const res: [id: string, data: MinamoTrainData][] = []; - if (choose.includes(com1)) { - const t = transform(map1, rot, flip); - res.push([ - com1, - { - map1: t, - map2: map1, - topoSimilarity: 1, - visionSimilarity: calculateVisualSimilarity(map1, t), - size: [map1[0].length, map1.length] - } - ]); - } - if (choose.includes(com2)) { - const t = transform(map1, rot, flip); - res.push([ - com2, - { - map1: t, - map2: map2, - topoSimilarity: simi, - visionSimilarity: calculateVisualSimilarity(t, map2), - size: [map1[0].length, map1.length] - } - ]); - } - if (choose.includes(com3)) { - const t = transform(map2, rot, flip); - res.push([ - com3, - { - map1: t, - map2: map1, - topoSimilarity: simi, - visionSimilarity: calculateVisualSimilarity(t, map1), - size: [map1[0].length, map1.length] - } - ]); - } - if (choose.includes(com4)) { - const t = transform(map2, rot, flip); - res.push([ - com4, - { - map1: t, - map2: map2, - topoSimilarity: 1, - visionSimilarity: calculateVisualSimilarity(t, map2), - size: [map1[0].length, map1.length] - } - ]); - } - - return res; - }) - .flat(); -} - -function generateSimilarData(id: string, map: number[][]) { - // 生成最多两个微调地图 - const width = map[0].length; - const height = map.length; - const num = Math.floor(Math.random() * 2); - const res: [id: string, data: MinamoTrainData][] = []; - - for (let i = 0; i < num; i++) { - const clone = map.map(v => v.slice()); - const prob = Math.random() * 0.3; - for (let ny = 0; ny < height; ny++) { - for (let nx = 0; nx < width; nx++) { - if (Math.random() > prob) { - // 有一定的概率进行微调 - continue; - } - if (Math.random() < 0.2) { - // 20% 概率与旁边图块互换位置 - const [dx, dy] = - directions[ - Math.floor(Math.random() * directions.length) - ]; - const px = nx + dx; - const py = ny + dy; - if (px < 0 || px >= width || py < 0 || py >= height) { - continue; - } - [clone[ny][nx], clone[py][px]] = [ - clone[py][px], - clone[ny][nx] - ]; - } else { - // 80% 概率替换当前图块 - clone[ny][nx] = Math.floor(Math.random() * tileType.size); - } - } - } - const id2 = `${id}.S${i}`; - const sid = `${id}:${id2}`; - const simi = compareMap(id, id2, map, clone); - - res.push([ - sid, - { - map1: map, - map2: clone, - size: [width, height], - topoSimilarity: simi, - visionSimilarity: calculateVisualSimilarity(map, clone) - } - ]); - } - return res; -} - -function generatePair( - data: Record, - id1: string, - id2: string, - map1: number[][], - map2: number[][], - size: [number, number] -) { - const topoSimilarity = compareMap(id1, id2, map1, map2); - const visionSimilarity = calculateVisualSimilarity(map1, map2); - const train: MinamoTrainData = { - map1, - map2, - topoSimilarity, - visionSimilarity, - size: size - }; - data[`${id1}:${id2}`] = train; - // 自身与自身对比的训练集,保证模型对相同地图输出 1 - const self1 = `${id1}:${id1}`; - const self2 = `${id2}:${id2}`; - const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1)); - if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) { - const selfTrain1: MinamoTrainData = { - map1: map1, - map2: map1, - topoSimilarity: 1, - visionSimilarity: 1, - size: size - }; - data[`${id1}:${id1}`] = selfTrain1; - } - if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) { - const selfTrain2: MinamoTrainData = { - map1: map2, - map2: map2, - topoSimilarity: 1, - visionSimilarity: 1, - size: size - }; - data[`${id2}:${id2}`] = selfTrain2; - } - // 翻转、旋转训练集 - Object.assign( - data, - Object.fromEntries( - generateTransformData(id1, id2, map1, map2, topoSimilarity) - ) - ); - // 地图微调训练集 - Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1))); -} - -function generateDataset( - floors: Map, - pairs: number[], - floorIds: string[] -): Record { - const data: Record = {}; - - const progress = new SingleBar({}, Presets.shades_classic); - - progress.start(pairs.length, 0); - - pairs.forEach((v, i) => { - const num1 = Math.floor(v / floorIds.length); - const num2 = v % floorIds.length; - const id1 = floorIds[num1]; - const id2 = floorIds[num2]; - const map1 = floors.get(id1)?.map; - const map2 = floors.get(id2)?.map; - if (!map1 || !map2) return; - const [w1, h1] = [map1[0].length, map1.length]; - const [w2, h2] = [map2[0].length, map2.length]; - if (w1 !== w2 || h1 !== h2) return; - generatePair(data, id1, id2, map1, map2, [w1, h1]); - progress.update(i + 1); - }); - - progress.stop(); - - return data; -} - -function parseAllData(data: Map): MinamoDataset { - const length = data.size; - const totalCount = Math.round((length * (length - 1)) / 2); - - const pairs = choosePair(length, 10000); - - console.log( - `✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairs.length} 个组合` - ); - - const trainData = generateDataset(data, pairs, [...data.keys()]); - - const dataset: MinamoDataset = { - datasetId: Math.floor(Math.random() * 1e12), - data: trainData - }; - - return dataset; -} - -function generateAssignedData( - data1: Map, - data2: Map, - count: [number, number] -): MinamoDataset { - const length = data1.size + data2.size; - const totalCount = data1.size * data2.size; - const count1 = Math.min(count[0], data1.size); - const count2 = Math.min(count[1], data2.size); - const keys1 = [...data1.keys()]; - const keys2 = [...data2.keys()]; - const choose1 = chooseFrom(keys1, count1); - - const trainData: Record = {}; - - console.log( - `✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${ - count1 * count2 - } 个组合` - ); - - const progress = new SingleBar({}, Presets.shades_classic); - progress.start(count1 * count2, 0); - let n = 0; - - for (const key1 of choose1) { - const choose2 = chooseFrom(keys2, count2); - for (const key2 of choose2) { - const { map: map1 } = data1.get(key1)!; - const { map: map2 } = data2.get(key2)!; - if (!map1 || !map2) continue; - const [w1, h1] = [map1[0].length, map1.length]; - const [w2, h2] = [map2[0].length, map2.length]; - if (w1 !== w2 || h1 !== h2) continue; - generatePair(trainData, key1, key2, map1, map2, [w1, h1]); - n++; - progress.update(n); - } - } - - progress.stop(); - - const dataset: MinamoDataset = { - datasetId: Math.floor(Math.random() * 1e12), - data: trainData - }; - - return dataset; -} - (async () => { if (!assigned) { const towers = await Promise.all( list.map(v => parseTowerInfo(v, 'minamo-config.json')) ); const floors = await getAllFloors(...towers); - const results = parseAllData(floors); + const results = parseMinamo(floors); await writeFile(output, JSON.stringify(results, void 0), 'utf-8'); const size = Object.keys(results.data).length; console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`); diff --git a/data/src/process/ginka.ts b/data/src/process/ginka.ts new file mode 100644 index 0000000..7964fbf --- /dev/null +++ b/data/src/process/ginka.ts @@ -0,0 +1,30 @@ +import { SingleBar, Presets } from 'cli-progress'; +import { GinkaTrainData, GinkaConfig, GinkaDataset } from 'src/types'; +import { FloorData } from 'src/utils'; + +export function parseGinka(data: Map) { + const resolved: Record = {}; + + const progress = new SingleBar({}, Presets.shades_classic); + progress.start(data.size, 0); + let i = 0; + + data.forEach((floor, key) => { + const config = floor.config as GinkaConfig; + const text = config.data[floor.id] ?? []; + resolved[key] = { + map: floor.map, + size: [floor.map[0].length, floor.map.length], + text: text + }; + i++; + progress.update(i); + }); + + const dataset: GinkaDataset = { + datasetId: Math.floor(Math.random() * 1e12), + data: resolved + }; + + return dataset; +} diff --git a/data/src/process/minamo.ts b/data/src/process/minamo.ts new file mode 100644 index 0000000..7edd800 --- /dev/null +++ b/data/src/process/minamo.ts @@ -0,0 +1,395 @@ +import { SingleBar, Presets } from 'cli-progress'; +import { compareMap } from 'src/topology/compare'; +import { directions, tileType } from 'src/topology/graph'; +import { rotateMap, mirrorMapX, mirrorMapY } from 'src/topology/transform'; +import { MinamoTrainData, MinamoDataset } from 'src/types'; +import { chooseFrom, FloorData } from 'src/utils'; +import { calculateVisualSimilarity } from 'src/vision/similarity'; + +function chooseN(maxCount: number, n: number) { + return chooseFrom( + Array(maxCount) + .fill(0) + .map((_, i) => i), + n + ); +} + +function choosePair(n: number, max: number = 1000) { + const totalCount = Math.round((n * (n - 1)) / 2); + const count = Math.min(totalCount, max); + const pairs: number[] = []; + for (let i = 0; i < n; i++) { + for (let j = i + 1; j < n; j++) { + pairs.push(i * n + j); + } + } + // 直接打乱后取前 count 个 + for (let i = pairs.length - 1; i > 0; i--) { + let randIndex = Math.floor(Math.random() * (i + 1)); + [pairs[i], pairs[randIndex]] = [pairs[randIndex], pairs[i]]; + } + + return pairs.slice(0, count); +} + +function transform(map: number[][], rot: number, flip: number) { + let res = map; + for (let i = 0; i < rot; i++) { + res = rotateMap(res); + } + if (flip & 0b01) { + res = mirrorMapX(res); + } + if (flip & 0b10) { + res = mirrorMapY(res); + } + return res; +} + +function generateTransformData( + id1: string, + id2: string, + map1: number[][], + map2: number[][], + simi: number +) { + const types: [rot: number, flip: number][] = []; + for (const rot of [0, 1, 2, 3]) { + for (const flip of [0b00, 0b01, 0b10, 0b11]) { + if (rot === 0 && flip === 0) continue; + types.push([rot, flip]); + } + } + // 随机抽取最多一个 + const trans = chooseFrom(types, Math.floor(Math.random() * 1)); + return trans + .map(([rot, flip]) => { + const com1 = `${id1}.${rot}.${flip}:${id1}`; + const com2 = `${id1}.${rot}.${flip}:${id2}`; + const com3 = `${id2}.${rot}.${flip}:${id1}`; + const com4 = `${id2}.${rot}.${flip}:${id2}`; + const choose = chooseFrom( + [com1, com2, com3, com4], + Math.floor(Math.random() * 2) + ); + const res: [id: string, data: MinamoTrainData][] = []; + if (choose.includes(com1)) { + const t = transform(map1, rot, flip); + res.push([ + com1, + { + map1: t, + map2: map1, + topoSimilarity: 1, + visionSimilarity: calculateVisualSimilarity(map1, t), + size: [map1[0].length, map1.length] + } + ]); + } + if (choose.includes(com2)) { + const t = transform(map1, rot, flip); + res.push([ + com2, + { + map1: t, + map2: map2, + topoSimilarity: simi, + visionSimilarity: calculateVisualSimilarity(t, map2), + size: [map1[0].length, map1.length] + } + ]); + } + if (choose.includes(com3)) { + const t = transform(map2, rot, flip); + res.push([ + com3, + { + map1: t, + map2: map1, + topoSimilarity: simi, + visionSimilarity: calculateVisualSimilarity(t, map1), + size: [map1[0].length, map1.length] + } + ]); + } + if (choose.includes(com4)) { + const t = transform(map2, rot, flip); + res.push([ + com4, + { + map1: t, + map2: map2, + topoSimilarity: 1, + visionSimilarity: calculateVisualSimilarity(t, map2), + size: [map1[0].length, map1.length] + } + ]); + } + + return res; + }) + .flat(); +} + +function generateSimilarData(id: string, map: number[][]) { + // 生成最多两个微调地图 + const width = map[0].length; + const height = map.length; + const num = Math.floor(Math.random() * 1); + const res: [id: string, data: MinamoTrainData][] = []; + + for (let i = 0; i < num; i++) { + const clone = map.map(v => v.slice()); + const prob = Math.random() * 0.3; + for (let ny = 0; ny < height; ny++) { + for (let nx = 0; nx < width; nx++) { + if (Math.random() > prob) { + // 有一定的概率进行微调 + continue; + } + if (Math.random() < 0.2) { + // 20% 概率与旁边图块互换位置 + const [dx, dy] = + directions[ + Math.floor(Math.random() * directions.length) + ]; + const px = nx + dx; + const py = ny + dy; + if (px < 0 || px >= width || py < 0 || py >= height) { + continue; + } + [clone[ny][nx], clone[py][px]] = [ + clone[py][px], + clone[ny][nx] + ]; + } else { + // 80% 概率替换当前图块 + clone[ny][nx] = Math.floor(Math.random() * tileType.size); + } + } + } + const id2 = `${id}.S${i}`; + const sid = `${id}:${id2}`; + const simi = compareMap(id, id2, map, clone); + + res.push([ + sid, + { + map1: map, + map2: clone, + size: [width, height], + topoSimilarity: simi, + visionSimilarity: calculateVisualSimilarity(map, clone) + } + ]); + } + return res; +} + +export function generateTrainData( + id1: string, + id2: string, + map1: number[][], + map2: number[][], + size: [number, number] +) { + const topoSimilarity = compareMap(id1, id2, map1, map2); + const visionSimilarity = calculateVisualSimilarity(map1, map2); + const train: MinamoTrainData = { + map1, + map2, + topoSimilarity, + visionSimilarity, + size: size + }; + const data: MinamoTrainData[] = []; + data.push(train); + // 自身与自身对比的训练集,保证模型对相同地图输出 1 + const self1 = `${id1}:${id1}`; + const self2 = `${id2}:${id2}`; + const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1)); + if (selfTrain.includes(self1)) { + const selfTrain1: MinamoTrainData = { + map1: map1, + map2: map1, + topoSimilarity: 1, + visionSimilarity: 1, + size: size + }; + data.push(selfTrain1); + } + if (selfTrain.includes(self2)) { + const selfTrain2: MinamoTrainData = { + map1: map2, + map2: map2, + topoSimilarity: 1, + visionSimilarity: 1, + size: size + }; + data.push(selfTrain2); + } + const transform = generateTransformData( + id1, + id2, + map1, + map2, + topoSimilarity + ); + const similar = generateSimilarData(id1, map1); + return [...data, ...transform.map(v => v[1]), ...similar.map(v => v[1])]; +} + +export function generatePair( + data: Record, + id1: string, + id2: string, + map1: number[][], + map2: number[][], + size: [number, number] +) { + const topoSimilarity = compareMap(id1, id2, map1, map2); + const visionSimilarity = calculateVisualSimilarity(map1, map2); + const train: MinamoTrainData = { + map1, + map2, + topoSimilarity, + visionSimilarity, + size: size + }; + data[`${id1}:${id2}`] = train; + // 自身与自身对比的训练集,保证模型对相同地图输出 1 + const self1 = `${id1}:${id1}`; + const self2 = `${id2}:${id2}`; + const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1)); + if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) { + const selfTrain1: MinamoTrainData = { + map1: map1, + map2: map1, + topoSimilarity: 1, + visionSimilarity: 1, + size: size + }; + data[`${id1}:${id1}`] = selfTrain1; + } + if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) { + const selfTrain2: MinamoTrainData = { + map1: map2, + map2: map2, + topoSimilarity: 1, + visionSimilarity: 1, + size: size + }; + data[`${id2}:${id2}`] = selfTrain2; + } + // 翻转、旋转训练集 + Object.assign( + data, + Object.fromEntries( + generateTransformData(id1, id2, map1, map2, topoSimilarity) + ) + ); + // 地图微调训练集 + Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1))); +} + +function generateDataset( + floors: Map, + pairs: number[], + floorIds: string[] +): Record { + const data: Record = {}; + + const progress = new SingleBar({}, Presets.shades_classic); + + progress.start(pairs.length, 0); + + pairs.forEach((v, i) => { + const num1 = Math.floor(v / floorIds.length); + const num2 = v % floorIds.length; + const id1 = floorIds[num1]; + const id2 = floorIds[num2]; + const map1 = floors.get(id1)?.map; + const map2 = floors.get(id2)?.map; + if (!map1 || !map2) return; + const [w1, h1] = [map1[0].length, map1.length]; + const [w2, h2] = [map2[0].length, map2.length]; + if (w1 !== w2 || h1 !== h2) return; + generatePair(data, id1, id2, map1, map2, [w1, h1]); + progress.update(i + 1); + }); + + progress.stop(); + + return data; +} + +export function parseMinamo(data: Map): MinamoDataset { + const length = data.size; + const totalCount = Math.round((length * (length - 1)) / 2); + + const pairs = choosePair(length, 10000); + + console.log( + `✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairs.length} 个组合` + ); + + const trainData = generateDataset(data, pairs, [...data.keys()]); + + const dataset: MinamoDataset = { + datasetId: Math.floor(Math.random() * 1e12), + data: trainData + }; + + return dataset; +} + +export function generateAssignedData( + data1: Map, + data2: Map, + count: [number, number] +): MinamoDataset { + const length = data1.size + data2.size; + const totalCount = data1.size * data2.size; + const count1 = Math.min(count[0], data1.size); + const count2 = Math.min(count[1], data2.size); + const keys1 = [...data1.keys()]; + const keys2 = [...data2.keys()]; + const choose1 = chooseFrom(keys1, count1); + + const trainData: Record = {}; + + console.log( + `✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${ + count1 * count2 + } 个组合` + ); + + const progress = new SingleBar({}, Presets.shades_classic); + progress.start(count1 * count2, 0); + let n = 0; + + for (const key1 of choose1) { + const choose2 = chooseFrom(keys2, count2); + for (const key2 of choose2) { + const { map: map1 } = data1.get(key1)!; + const { map: map2 } = data2.get(key2)!; + if (!map1 || !map2) continue; + const [w1, h1] = [map1[0].length, map1.length]; + const [w2, h2] = [map2[0].length, map2.length]; + if (w1 !== w2 || h1 !== h2) continue; + generatePair(trainData, key1, key2, map1, map2, [w1, h1]); + n++; + progress.update(n); + } + } + + progress.stop(); + + const dataset: MinamoDataset = { + datasetId: Math.floor(Math.random() * 1e12), + data: trainData + }; + + return dataset; +} diff --git a/data/src/types.ts b/data/src/types.ts index 3af1d2c..cc8e968 100644 --- a/data/src/types.ts +++ b/data/src/types.ts @@ -11,3 +11,33 @@ export interface TowerInfo { floorIds: string[]; config: BaseConfig; } + +export interface GinkaConfig extends BaseConfig { + data: Record; +} + +export interface GinkaTrainData { + text: string[]; + map: number[][]; + size: [number, number]; +} + +export interface GinkaDataset { + datasetId: number; + data: Record; +} + +export interface MinamoConfig extends BaseConfig {} + +export interface MinamoTrainData { + map1: number[][]; + map2: number[][]; + topoSimilarity: number; + visionSimilarity: number; + size: [number, number]; +} + +export interface MinamoDataset { + datasetId: number; + data: Record; +} diff --git a/ginka/dataset.py b/ginka/dataset.py index 5bc1c36..2c44aa9 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -1,4 +1,5 @@ import json +import random import torch import torch.nn.functional as F from torch.utils.data import Dataset @@ -16,6 +17,12 @@ def load_data(path: str): return data_list +def load_minamo_gan_data(data: list): + res = list() + for one in data: + res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True)) + return res + class GinkaDataset(Dataset): def __init__(self, data_path: str, device, minamo: MinamoModel): self.data = load_data(data_path) # 自定义数据加载函数 @@ -40,4 +47,45 @@ class GinkaDataset(Dataset): "target_topo_feat": topo_feat, "target": target, } - \ No newline at end of file + +class MinamoGANDataset(Dataset): + def __init__(self, refer_data_path): + self.refer = load_minamo_gan_data(load_data(refer_data_path)) + self.data = list().extend(self.refer) + + def set_data(self, data: list): + self.data = data.extend(self.refer) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + + map1, map2, vis_sim, topo_sim, review = item + map1 = torch.ShortTensor(map1) + map2 = torch.ShortTensor(map2) + # 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换 + if review: + map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W] + map2 = F.one_hot(map2, num_classes=32).permute(2, 0, 1).float() # [32, H, W] + + min_main = random.uniform(0.75, 0.9) + max_main = random.uniform(0.9, 1) + epsilon = random.uniform(0, 0.25) + + if review: + map1 = random_smooth_onehot(map1, min_main, max_main, epsilon) + map2 = random_smooth_onehot(map2, min_main, max_main, epsilon) + + graph1 = differentiable_convert_to_data(map1) + graph2 = differentiable_convert_to_data(map2) + + return ( + map1, + map2, + torch.FloatTensor([vis_sim]), + torch.FloatTensor([topo_sim]), + graph1, + graph2 + ) \ No newline at end of file diff --git a/ginka/model/loss.py b/ginka/model/loss.py index ab7294e..8d72156 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -1,4 +1,5 @@ import math +from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F @@ -253,11 +254,8 @@ class GinkaLoss(nn.Module): minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim minamo_loss = (1.0 - minamo_sim).mean() - print( - minamo_loss.item(), - class_loss.item(), - entrance_loss.item(), - count_loss.item() + tqdm.write( + f"{minamo_loss.item():.8f}, {class_loss.item():.8f}, {entrance_loss.item():.8f}, {count_loss.item():.8f}" ) losses = [ @@ -267,7 +265,4 @@ class GinkaLoss(nn.Module): count_loss * self.weight[3] ] - # 梯度归一化 - scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses] - total_loss = sum(scaled_losses) - return total_loss, sum(losses) + return sum(losses) diff --git a/ginka/train.py b/ginka/train.py index 86ff5cb..84eca3f 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -78,7 +78,7 @@ def train(): _, output_softmax = model(feat_vec) # 计算损失 - scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat) + losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat) # 反向传播 losses.backward() @@ -111,7 +111,7 @@ def train(): print(torch.argmax(output, dim=1)[0]) # 计算损失 - scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat) + losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat) loss_val += losses.item() avg_val_loss = loss_val / len(dataloader_val) diff --git a/ginka/train_gan.py b/ginka/train_gan.py new file mode 100644 index 0000000..52ed152 --- /dev/null +++ b/ginka/train_gan.py @@ -0,0 +1,305 @@ +import argparse +import socket +import struct +import os +from datetime import datetime +import torch +import torch.optim as optim +import torch.nn.functional as F +from torch_geometric.loader import DataLoader +from tqdm import tqdm +import cv2 +import numpy as np +from .model.model import GinkaModel +from .model.loss import GinkaLoss +from .dataset import GinkaDataset, MinamoGANDataset +from minamo.model.model import MinamoModel +from minamo.model.loss import MinamoLoss +from shared.image import matrix_to_image_cv + +BATCH_SIZE = 32 +EPOCHS_GINKA = 30 +EPOCHS_MINAMO = 10 +SOCKET_PATH = "./tmp/ginka_uds" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +os.makedirs("result", exist_ok=True) +os.makedirs("result/ginka_checkpoint", exist_ok=True) +os.makedirs("tmp", exist_ok=True) + +def parse_arguments(): + parser = argparse.ArgumentParser(description="training codes") + parser.add_argument("--resume", type=bool, default=False) + parser.add_argument("--from_state", type=str, default="result/ginka.pth") + parser.add_argument("--train", type=str, default="ginka-dataset.json") + parser.add_argument("--validate", type=str, default='ginka-eval.json') + parser.add_argument("--from_cycle", type=int, default=2) + parser.add_argument("--to_cycle", type=int, default=100) + args = parser.parse_args() + return args + +def parse_ginka_batch(batch): + target = batch["target"].to(device) + target_vision_feat = batch["target_vision_feat"].to(device) + target_topo_feat = batch["target_topo_feat"].to(device) + feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) + + return target, target_vision_feat, target_topo_feat, feat_vec + +def parse_minamo_batch(batch): + map1, map2, vision_simi, topo_simi, graph1, graph2 = batch + map1 = map1.to(device) # 转为 [B, C, H, W] + map2 = map2.to(device) + topo_simi = topo_simi.to(device) + vision_simi = vision_simi.to(device) + graph1 = graph1.to(device) + graph2 = graph2.to(device) + return map1, map2, vision_simi, topo_simi, graph1, graph2 + +def send_all(sock, data): + total_sent = 0 + while total_sent < len(data): + sent = sock.send(data[total_sent:]) + if sent == 0: + raise RuntimeError("Socket connection broken") + total_sent += sent + +def parse_minamo_data(sock: socket.socket, maps: np.ndarray): + # 数据通讯 node 输出协议,单位字节: + # 2 - Tensor count; 2 - Review count. Review is right behind train data; + # 1*tc - Compare count for every map tensor delivered. + # 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo; + # N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor. + _, _, H, W = maps.shape + tc_buf = sock.recv(2) + rc_buf = sock.recv(2) + tc = struct.unpack('>h', tc_buf)[0] + rc = struct.unpack('>h', rc_buf)[0] + count_buf = sock.recv(1 * tc) + count: list = struct.unpack(f">{tc}b", count_buf)[0] + N = sum(count) + sim_buf = sock.recv(2 * 4 * (N + rc)) + com_buf = sock.recv(N * 1 * H * W) + review_buf = sock.recv(rc * 2 * H * W) if rc > 0 else bytes() + + sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)[0] + com = struct.unpack(f">{N * 1 * H * W}b", com_buf)[0] + review = struct.unpack(f">{rc * 2 * H * W}", review_buf)[0] if rc > 0 else list() + + res = list() + flatten_idx = 0 + # 读取当前这一轮生成器的数据 + for idx in range(N): + com_count = count[idx] + for i in range(com_count): + com_start = flatten_idx * H * W + com_end = (flatten_idx + 1) * H * W + vis_sim = sim[flatten_idx * 2] + topo_sim = sim[flatten_idx * 2 + 1] + com_data = com[com_start:com_end] + flatten_idx += 1 + com_map = np.fromiter(com_data, np.int8).view(H, W) + # map1, map2, vision_similarity, topo_similarity, is_review + res.append((maps[idx], com_map, vis_sim, topo_sim, False)) + + return res + +def train(): + print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") + + args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json') + + ginka = GinkaModel() + ginka.to(device) + minamo = MinamoModel(32) + minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) + minamo.to(device) + minamo.eval() + + # 准备数据集 + ginka_dataset = GinkaDataset(args.train, device, minamo) + ginka_dataset_val = GinkaDataset(args.validate, device, minamo) + minamo_dataset = MinamoGANDataset() + minamo_dataset_val = MinamoGANDataset() + ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True) + ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True) + minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True) + minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE, shuffle=True) + + # 设定优化器与调度器 + optimizer_ginka = optim.AdamW(ginka.parameters(), lr=1e-3) + scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6) + criterion_ginka = GinkaLoss(minamo) + + optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3) + scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6) + criterion_minamo = MinamoLoss() + + # 用于生成图片 + tile_dict = dict() + for file in os.listdir('tiles'): + name = os.path.splitext(file)[0] + tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) + + # 与 node 端通讯 + server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server.bind(SOCKET_PATH) + server.listen(1) + + if args.resume: + data = torch.load(args.from_state, map_location=device) + ginka.load_state_dict(data["model_state"], strict=False) + if args.load_optim: + optimizer_ginka.load_state_dict(data["optimizer_state"]) + print("Train from loaded state.") + + else: + # 从头开始训练的话,初始时先把 minamo 损失值权重改为 0 + criterion_ginka.weight[0] = 0.0 + + for cycle in tqdm(range(args.from_cycle, args.to_cycle)): + # -------------------- 训练生成器 + gen_list: np.ndarray = np.empty(np.int8) + prob_list: np.ndarray = np.empty(np.float32) + for epoch in tqdm(range(args.epochs), desc="Training Ginka Model"): + ginka.train() + minamo.eval() + total_loss = 0 + + # 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来 + if not args.resume and epoch == 10: + criterion_ginka.weight[0] = 0.5 + + for batch in ginka_dataloader: + # 数据迁移到设备 + target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) + # 前向传播 + optimizer_ginka.zero_grad() + _, output_softmax = ginka(feat_vec) + # 计算损失 + losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat) + # 反向传播 + losses.backward() + optimizer_ginka.step() + total_loss += losses.item() + + avg_loss = total_loss / len(ginka_dataloader) + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}") + + # 学习率调整 + scheduler_ginka.step() + + if (epoch + 1) % 5 == 0: + loss_val = 0 + ginka.eval() + idx = 0 + with torch.no_grad(): + for batch in ginka_dataloader_val: + target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) + output, output_softmax = ginka(feat_vec) + losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat) + loss_val += losses.item() + if epoch + 1 == EPOCHS_GINKA: + # 最后一次验证的时候顺带生成图片 + prob = output_softmax.cpu().numpy() + np.concatenate((prob_list, prob), axis=1) + map_matrix = torch.argmax(output, dim=1).cpu().numpy() + gen_list = np.concatenate((gen_list, map_matrix), axis=1) + for matrix in map_matrix: + image = matrix_to_image_cv(matrix, tile_dict) + cv2.imwrite(f"result/ginka_img/{idx}.png", image) + idx += 1 + + avg_val_loss = loss_val / len(ginka_dataloader_val) + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") + torch.save({ + "model_state": ginka.state_dict() + }, f"result/ginka_checkpoint/{epoch + 1}.pth") + + tqdm.write(f"Cycle {cycle} Ginka train ended.") + torch.save({ + "model_state": ginka.state_dict() + }, f"result/ginka.pth") + + # -------------------- 生成 Minamo 的训练数据 + + # 数据通讯 python 输出协议,单位字节: + # 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type. + N, H, W = gen_list.shape + gen_bytes = gen_list.astype(np.int8).tobytes() + buf = bytearray() + buf.extend(struct.pack('>h', N)) # Tensor count + buf.extend(struct.pack('>b', H)) # Map height + buf.extend(struct.pack('>b', W)) # Map width + buf.extend(gen_bytes) # Map tensor + server.sendall(buf) + data = parse_minamo_data(server, prob_list) + minamo_dataset.set_data(data) + + # -------------------- 训练判别器 + for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"): + ginka.eval() + minamo.train() + total_loss = 0 + + for batch in minamo_dataloader: + map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch) + + if map1.shape[0] == 1: + continue + + # 前向传播 + optimizer_minamo.zero_grad() + vision_feat1, topo_feat1 = minamo(map1, graph1) + vision_feat2, topo_feat2 = minamo(map2, graph2) + + vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) + topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) + + # 计算损失 + loss = criterion_minamo(vision_pred, topo_pred, vision_simi, topo_simi) + + # 反向传播 + loss.backward() + optimizer_minamo.step() + total_loss += loss.item() + + ave_loss = total_loss / len(minamo_dataloader) + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}") + + scheduler_minamo.step() + + # 每十轮推理一次验证集 + if (epoch + 1) % 5 == 0: + minamo.eval() + val_loss = 0 + with torch.no_grad(): + for val_batch in tqdm(minamo_dataloader_val, leave=False): + map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch) + + vision_feat1, topo_feat1 = minamo(map1_val, graph1) + vision_feat2, topo_feat2 = minamo(map2_val, graph2) + + vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) + topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) + + # 计算损失 + loss_val = criterion_minamo(vision_pred, topo_pred, vision_simi_val, topo_simi_val) + val_loss += loss_val.item() + + avg_val_loss = val_loss / len(minamo_dataloader_val) + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") + torch.save({ + "model_state": minamo.state_dict() + }, f"result/minamo_checkpoint/{epoch + 1}.pth") + + tqdm.write(f"Cycle {cycle} Minamo train ended.") + torch.save({ + "model_state": minamo.state_dict() + }, f"result/ginka.pth") + + print("Train ended.") + +if __name__ == "__main__": + torch.set_num_threads(4) + train() diff --git a/ginka/validate.py b/ginka/validate.py index 89ef941..c82e6d0 100644 --- a/ginka/validate.py +++ b/ginka/validate.py @@ -10,59 +10,11 @@ from minamo.model.model import MinamoModel from .dataset import GinkaDataset from .model.loss import GinkaLoss from .model.model import GinkaModel +from shared.image import matrix_to_image_cv device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs('result/ginka_img', exist_ok=True) -def blend_alpha(bg, fg, alpha): - """ 使用 alpha 通道混合前景图块和背景图 """ - for c in range(3): # 只混合 RGB 三个通道 - bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c] - return bg - -def matrix_to_image_cv(map_matrix, tile_set, tile_size=32): - """ - 使用OpenCV加速的版本(适合大尺寸地图) - :param map_matrix: [H, W] 的numpy数组 - :param tile_set: 字典 {tile_id: cv2图像(BGR格式)} - :param tile_size: 图块边长(像素) - """ - H, W = map_matrix.shape # 获取地图尺寸 - canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景) - - # 遍历地图矩阵 - for row in range(H): - for col in range(W): - tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型 - x, y = col * tile_size, row * tile_size # 计算像素位置 - - # 先绘制地面(0) - if '0' in tile_set: - canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB - - if tile_index == '11': - if row == 0: - tile_index = '11_1' - elif row == W - 1: - tile_index = '11_3' - elif col == 0: - tile_index = '11_2' - elif col == H - 1: - tile_index = '11_4' - - # 叠加其他透明图块 - if tile_index in tile_set and tile_index != 0: - tile_rgba = tile_set[tile_index] - tile_rgb = tile_rgba[:, :, :3] # 提取 RGB - alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha - - # 混合当前图块到背景 - canvas[y:y+tile_size, x:x+tile_size] = blend_alpha( - canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha - ) - - return canvas - def validate(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.") model = GinkaModel() diff --git a/shared/image.py b/shared/image.py new file mode 100644 index 0000000..364e6ae --- /dev/null +++ b/shared/image.py @@ -0,0 +1,50 @@ +import numpy as np + +def blend_alpha(bg, fg, alpha): + """ 使用 alpha 通道混合前景图块和背景图 """ + for c in range(3): # 只混合 RGB 三个通道 + bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c] + return bg + +def matrix_to_image_cv(map_matrix, tile_set, tile_size=32): + """ + 使用OpenCV加速的版本(适合大尺寸地图) + :param map_matrix: [H, W] 的numpy数组 + :param tile_set: 字典 {tile_id: cv2图像(BGR格式)} + :param tile_size: 图块边长(像素) + """ + H, W = map_matrix.shape # 获取地图尺寸 + canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景) + + # 遍历地图矩阵 + for row in range(H): + for col in range(W): + tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型 + x, y = col * tile_size, row * tile_size # 计算像素位置 + + # 先绘制地面(0) + if '0' in tile_set: + canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB + + if tile_index == '11': + if row == 0: + tile_index = '11_1' + elif row == W - 1: + tile_index = '11_3' + elif col == 0: + tile_index = '11_2' + elif col == H - 1: + tile_index = '11_4' + + # 叠加其他透明图块 + if tile_index in tile_set and tile_index != 0: + tile_rgba = tile_set[tile_index] + tile_rgb = tile_rgba[:, :, :3] # 提取 RGB + alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha + + # 混合当前图块到背景 + canvas[y:y+tile_size, x:x+tile_size] = blend_alpha( + canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha + ) + + return canvas \ No newline at end of file