mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 10:21:15 +08:00
feat: 使用 IPC 通讯并实现单脚本对抗训练
This commit is contained in:
parent
4721e9a141
commit
8dea79a9f0
@ -8,6 +8,7 @@
|
|||||||
"minamo": "tsx ./src/minamo.ts",
|
"minamo": "tsx ./src/minamo.ts",
|
||||||
"merge": "tsx ./src/merge.ts",
|
"merge": "tsx ./src/merge.ts",
|
||||||
"review": "tsx ./src/review.ts",
|
"review": "tsx ./src/review.ts",
|
||||||
|
"gan": "tsx ./src/gan.ts",
|
||||||
"test:topo": "tsx ./src/topology/test.ts",
|
"test:topo": "tsx ./src/topology/test.ts",
|
||||||
"test:vision": "tsx ./src/vision/test.ts"
|
"test:vision": "tsx ./src/vision/test.ts"
|
||||||
},
|
},
|
||||||
|
|||||||
131
data/src/gan.ts
Normal file
131
data/src/gan.ts
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
import { createConnection, Socket } from 'net';
|
||||||
|
import { chooseFrom, FloorData, readOne } from './utils';
|
||||||
|
import { MinamoTrainData } from './types';
|
||||||
|
import { generateTrainData } from './process/minamo';
|
||||||
|
|
||||||
|
const SOCKET_FILE = '../tmp/ginka_uds';
|
||||||
|
const [refer] = process.argv.slice(2);
|
||||||
|
|
||||||
|
let id = 0;
|
||||||
|
|
||||||
|
function readMap(count: number, buffer: Buffer, h: number, w: number) {
|
||||||
|
const area = w * h;
|
||||||
|
|
||||||
|
const maps: number[][][] = Array.from<number[][]>({
|
||||||
|
length: count
|
||||||
|
}).map(() => {
|
||||||
|
return Array.from<number[]>({ length: h }).map(() => {
|
||||||
|
return Array.from<number>({ length: w }).fill(0);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
buffer.subarray(4).forEach((v, i) => {
|
||||||
|
const n = Math.floor(i / area);
|
||||||
|
const y = Math.floor((i % area) / w);
|
||||||
|
const x = i % w;
|
||||||
|
maps[n][y][x] = v;
|
||||||
|
});
|
||||||
|
|
||||||
|
return maps;
|
||||||
|
}
|
||||||
|
|
||||||
|
function generateGANData(
|
||||||
|
keys: string[],
|
||||||
|
refer: Map<string, FloorData>,
|
||||||
|
map: number[][]
|
||||||
|
) {
|
||||||
|
const id2 = `$${id++}`;
|
||||||
|
const toTrain = chooseFrom(keys, 4);
|
||||||
|
const data = toTrain.map<MinamoTrainData[]>(v => {
|
||||||
|
const floor = refer.get(v);
|
||||||
|
if (!floor) return [];
|
||||||
|
const size1: [number, number] = [floor.map[0].length, floor.map.length];
|
||||||
|
const size2: [number, number] = [map[0].length, map.length];
|
||||||
|
if (size1[0] !== size2[0] || size1[1] !== size2[1]) return [];
|
||||||
|
|
||||||
|
return generateTrainData(v, id2, floor.map, map, size1);
|
||||||
|
});
|
||||||
|
return data.flat();
|
||||||
|
}
|
||||||
|
|
||||||
|
(async () => {
|
||||||
|
const referTower = await readOne(refer);
|
||||||
|
const keys = [...referTower.keys()];
|
||||||
|
|
||||||
|
const client = createConnection(SOCKET_FILE, () => {
|
||||||
|
console.log(`UDS IPC connected successfully.`);
|
||||||
|
// 发送四字节数据表示连接成功
|
||||||
|
client.write(new Uint8Array([0x00, 0x00, 0x00, 0x00]));
|
||||||
|
});
|
||||||
|
|
||||||
|
client.on('data', buffer => {
|
||||||
|
// 暂时不考虑流式传输,如果后续数据量非常大,再考虑优化
|
||||||
|
// 数据通讯 node 输入协议,单位字节:
|
||||||
|
// 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
|
||||||
|
const count = buffer.readInt16BE();
|
||||||
|
if (buffer.length - 4 !== count * 32 * 32) {
|
||||||
|
client.write(`ERROR: byte length not match.`);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const h = buffer.readInt8(2);
|
||||||
|
const w = buffer.readInt8(3);
|
||||||
|
const map = readMap(count, buffer, h, w);
|
||||||
|
const simData = map.map(v => generateGANData(keys, referTower, v));
|
||||||
|
const rc = 0;
|
||||||
|
const compareData = simData.flat();
|
||||||
|
const reviewData: MinamoTrainData[] = [];
|
||||||
|
|
||||||
|
// 数据通讯 node 输出协议,单位字节:
|
||||||
|
// 2 - Tensor count; 2 - Review count. Review is right behind train data;
|
||||||
|
// 1*tc - Compare count for every map tensor delivered.
|
||||||
|
// 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo;
|
||||||
|
// N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor.
|
||||||
|
const toSend = Buffer.alloc(
|
||||||
|
2 + // Tensor count
|
||||||
|
2 + // Review count
|
||||||
|
count + // Compare count
|
||||||
|
2 * (count + rc) + // Similarity data
|
||||||
|
compareData.length * 1 * h * w + // Compare map
|
||||||
|
rc * 2 * h * w, // Review map
|
||||||
|
0
|
||||||
|
);
|
||||||
|
let offset = 0;
|
||||||
|
toSend.writeInt16BE(count); // Tensor count
|
||||||
|
toSend.writeInt16BE(0, 2); // Review count
|
||||||
|
offset += 2 + 2;
|
||||||
|
// Compare count
|
||||||
|
toSend.set(
|
||||||
|
simData.map(v => v.length),
|
||||||
|
offset
|
||||||
|
);
|
||||||
|
offset += count;
|
||||||
|
// Similarity data
|
||||||
|
compareData.forEach(v => {
|
||||||
|
toSend.writeFloatBE(v.visionSimilarity, offset);
|
||||||
|
offset += 4;
|
||||||
|
toSend.writeFloatBE(v.topoSimilarity, offset);
|
||||||
|
offset += 4;
|
||||||
|
});
|
||||||
|
// Compare map
|
||||||
|
toSend.set(
|
||||||
|
compareData.map(v => v.map1).flat(2),
|
||||||
|
offset // Set from Compare map
|
||||||
|
);
|
||||||
|
offset += compareData.length * 1 * h * w;
|
||||||
|
// Review map
|
||||||
|
toSend.set(
|
||||||
|
reviewData.map(v => [v.map1, v.map2]).flat(3),
|
||||||
|
offset // Set from last chunk
|
||||||
|
);
|
||||||
|
|
||||||
|
client.write(toSend);
|
||||||
|
});
|
||||||
|
|
||||||
|
client.on('end', () => {
|
||||||
|
console.log(`Connection lose.`);
|
||||||
|
});
|
||||||
|
|
||||||
|
client.on('error', () => {
|
||||||
|
client.end();
|
||||||
|
});
|
||||||
|
})();
|
||||||
@ -1,61 +1,15 @@
|
|||||||
import { writeFile } from 'fs-extra';
|
import { writeFile } from 'fs-extra';
|
||||||
import { FloorData, getAllFloors, parseTowerInfo } from './utils';
|
import { getAllFloors, parseTowerInfo } from './utils';
|
||||||
import { Presets, SingleBar } from 'cli-progress';
|
import { parseGinka } from './process/ginka';
|
||||||
|
|
||||||
interface GinkaConfig {
|
|
||||||
clip: {
|
|
||||||
defaults: [number, number, number, number];
|
|
||||||
special: Record<string, [number, number, number, number]>;
|
|
||||||
};
|
|
||||||
data: Record<string, string[]>;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface GinkaTrainData {
|
|
||||||
text: string[];
|
|
||||||
map: number[][];
|
|
||||||
size: [number, number];
|
|
||||||
}
|
|
||||||
|
|
||||||
interface GinkaDataset {
|
|
||||||
datasetId: number;
|
|
||||||
data: Record<string, GinkaTrainData>;
|
|
||||||
}
|
|
||||||
|
|
||||||
const [output, ...list] = process.argv.slice(2);
|
const [output, ...list] = process.argv.slice(2);
|
||||||
|
|
||||||
function parseAllData(data: Map<string, FloorData>) {
|
|
||||||
const resolved: Record<string, GinkaTrainData> = {};
|
|
||||||
|
|
||||||
const progress = new SingleBar({}, Presets.shades_classic);
|
|
||||||
progress.start(data.size, 0);
|
|
||||||
let i = 0;
|
|
||||||
|
|
||||||
data.forEach((floor, key) => {
|
|
||||||
const config = floor.config as GinkaConfig;
|
|
||||||
const text = config.data[floor.id] ?? [];
|
|
||||||
resolved[key] = {
|
|
||||||
map: floor.map,
|
|
||||||
size: [floor.map[0].length, floor.map.length],
|
|
||||||
text: text
|
|
||||||
};
|
|
||||||
i++;
|
|
||||||
progress.update(i);
|
|
||||||
});
|
|
||||||
|
|
||||||
const dataset: GinkaDataset = {
|
|
||||||
datasetId: Math.floor(Math.random() * 1e12),
|
|
||||||
data: resolved
|
|
||||||
};
|
|
||||||
|
|
||||||
return dataset;
|
|
||||||
}
|
|
||||||
|
|
||||||
(async () => {
|
(async () => {
|
||||||
const towers = await Promise.all(
|
const towers = await Promise.all(
|
||||||
list.map(v => parseTowerInfo(v, 'ginka-config.json'))
|
list.map(v => parseTowerInfo(v, 'ginka-config.json'))
|
||||||
);
|
);
|
||||||
const floors = await getAllFloors(...towers);
|
const floors = await getAllFloors(...towers);
|
||||||
const results = parseAllData(floors);
|
const results = parseGinka(floors);
|
||||||
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
|
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
|
||||||
const size = Object.keys(results.data).length;
|
const size = Object.keys(results.data).length;
|
||||||
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个地图`);
|
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个地图`);
|
||||||
|
|||||||
@ -1,32 +1,6 @@
|
|||||||
import { writeFile } from 'fs-extra';
|
import { writeFile } from 'fs-extra';
|
||||||
import {
|
import { readOne, getAllFloors, parseTowerInfo } from './utils';
|
||||||
FloorData,
|
import { generateAssignedData, parseMinamo } from './process/minamo';
|
||||||
readOne,
|
|
||||||
getAllFloors,
|
|
||||||
parseTowerInfo,
|
|
||||||
chooseFrom
|
|
||||||
} from './utils';
|
|
||||||
import { compareMap } from './topology/compare';
|
|
||||||
import { mirrorMapX, mirrorMapY, rotateMap } from './topology/transform';
|
|
||||||
import { directions, tileType } from './topology/graph';
|
|
||||||
import { calculateVisualSimilarity } from './vision/similarity';
|
|
||||||
import { BaseConfig } from './types';
|
|
||||||
import { Presets, SingleBar } from 'cli-progress';
|
|
||||||
|
|
||||||
interface MinamoConfig extends BaseConfig {}
|
|
||||||
|
|
||||||
interface MinamoTrainData {
|
|
||||||
map1: number[][];
|
|
||||||
map2: number[][];
|
|
||||||
topoSimilarity: number;
|
|
||||||
visionSimilarity: number;
|
|
||||||
size: [number, number];
|
|
||||||
}
|
|
||||||
|
|
||||||
interface MinamoDataset {
|
|
||||||
datasetId: number;
|
|
||||||
data: Record<string, MinamoTrainData>;
|
|
||||||
}
|
|
||||||
|
|
||||||
const [output, ...list] = process.argv.slice(2);
|
const [output, ...list] = process.argv.slice(2);
|
||||||
// 判断 assigned 模式,此模式下只会对前两个塔处理,会在这两个塔之间对比,而单个塔的地图不会对比
|
// 判断 assigned 模式,此模式下只会对前两个塔处理,会在这两个塔之间对比,而单个塔的地图不会对比
|
||||||
@ -40,348 +14,13 @@ function parseAssigned(arg: string): [number, number] {
|
|||||||
return [parseInt(a) || 100, parseInt(b) || 100];
|
return [parseInt(a) || 100, parseInt(b) || 100];
|
||||||
}
|
}
|
||||||
|
|
||||||
function chooseN(maxCount: number, n: number) {
|
|
||||||
return chooseFrom(
|
|
||||||
Array(maxCount)
|
|
||||||
.fill(0)
|
|
||||||
.map((_, i) => i),
|
|
||||||
n
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
function choosePair(n: number, max: number = 1000) {
|
|
||||||
const totalCount = Math.round((n * (n - 1)) / 2);
|
|
||||||
const count = Math.min(totalCount, max);
|
|
||||||
const pairs: number[] = [];
|
|
||||||
for (let i = 0; i < n; i++) {
|
|
||||||
for (let j = i + 1; j < n; j++) {
|
|
||||||
pairs.push(i * n + j);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 直接打乱后取前 count 个
|
|
||||||
for (let i = pairs.length - 1; i > 0; i--) {
|
|
||||||
let randIndex = Math.floor(Math.random() * (i + 1));
|
|
||||||
[pairs[i], pairs[randIndex]] = [pairs[randIndex], pairs[i]];
|
|
||||||
}
|
|
||||||
|
|
||||||
return pairs.slice(0, count);
|
|
||||||
}
|
|
||||||
|
|
||||||
function transform(map: number[][], rot: number, flip: number) {
|
|
||||||
let res = map;
|
|
||||||
for (let i = 0; i < rot; i++) {
|
|
||||||
res = rotateMap(res);
|
|
||||||
}
|
|
||||||
if (flip & 0b01) {
|
|
||||||
res = mirrorMapX(res);
|
|
||||||
}
|
|
||||||
if (flip & 0b10) {
|
|
||||||
res = mirrorMapY(res);
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
function generateTransformData(
|
|
||||||
id1: string,
|
|
||||||
id2: string,
|
|
||||||
map1: number[][],
|
|
||||||
map2: number[][],
|
|
||||||
simi: number
|
|
||||||
) {
|
|
||||||
const types: [rot: number, flip: number][] = [];
|
|
||||||
for (const rot of [0, 1, 2, 3]) {
|
|
||||||
for (const flip of [0b00, 0b01, 0b10, 0b11]) {
|
|
||||||
if (rot === 0 && flip === 0) continue;
|
|
||||||
types.push([rot, flip]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 随机抽取最多一个
|
|
||||||
const trans = chooseFrom(types, Math.floor(Math.random() * 1));
|
|
||||||
return trans
|
|
||||||
.map(([rot, flip]) => {
|
|
||||||
const com1 = `${id1}.${rot}.${flip}:${id1}`;
|
|
||||||
const com2 = `${id1}.${rot}.${flip}:${id2}`;
|
|
||||||
const com3 = `${id2}.${rot}.${flip}:${id1}`;
|
|
||||||
const com4 = `${id2}.${rot}.${flip}:${id2}`;
|
|
||||||
const choose = chooseFrom(
|
|
||||||
[com1, com2, com3, com4],
|
|
||||||
Math.floor(Math.random() * 2)
|
|
||||||
);
|
|
||||||
const res: [id: string, data: MinamoTrainData][] = [];
|
|
||||||
if (choose.includes(com1)) {
|
|
||||||
const t = transform(map1, rot, flip);
|
|
||||||
res.push([
|
|
||||||
com1,
|
|
||||||
{
|
|
||||||
map1: t,
|
|
||||||
map2: map1,
|
|
||||||
topoSimilarity: 1,
|
|
||||||
visionSimilarity: calculateVisualSimilarity(map1, t),
|
|
||||||
size: [map1[0].length, map1.length]
|
|
||||||
}
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
if (choose.includes(com2)) {
|
|
||||||
const t = transform(map1, rot, flip);
|
|
||||||
res.push([
|
|
||||||
com2,
|
|
||||||
{
|
|
||||||
map1: t,
|
|
||||||
map2: map2,
|
|
||||||
topoSimilarity: simi,
|
|
||||||
visionSimilarity: calculateVisualSimilarity(t, map2),
|
|
||||||
size: [map1[0].length, map1.length]
|
|
||||||
}
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
if (choose.includes(com3)) {
|
|
||||||
const t = transform(map2, rot, flip);
|
|
||||||
res.push([
|
|
||||||
com3,
|
|
||||||
{
|
|
||||||
map1: t,
|
|
||||||
map2: map1,
|
|
||||||
topoSimilarity: simi,
|
|
||||||
visionSimilarity: calculateVisualSimilarity(t, map1),
|
|
||||||
size: [map1[0].length, map1.length]
|
|
||||||
}
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
if (choose.includes(com4)) {
|
|
||||||
const t = transform(map2, rot, flip);
|
|
||||||
res.push([
|
|
||||||
com4,
|
|
||||||
{
|
|
||||||
map1: t,
|
|
||||||
map2: map2,
|
|
||||||
topoSimilarity: 1,
|
|
||||||
visionSimilarity: calculateVisualSimilarity(t, map2),
|
|
||||||
size: [map1[0].length, map1.length]
|
|
||||||
}
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
})
|
|
||||||
.flat();
|
|
||||||
}
|
|
||||||
|
|
||||||
function generateSimilarData(id: string, map: number[][]) {
|
|
||||||
// 生成最多两个微调地图
|
|
||||||
const width = map[0].length;
|
|
||||||
const height = map.length;
|
|
||||||
const num = Math.floor(Math.random() * 2);
|
|
||||||
const res: [id: string, data: MinamoTrainData][] = [];
|
|
||||||
|
|
||||||
for (let i = 0; i < num; i++) {
|
|
||||||
const clone = map.map(v => v.slice());
|
|
||||||
const prob = Math.random() * 0.3;
|
|
||||||
for (let ny = 0; ny < height; ny++) {
|
|
||||||
for (let nx = 0; nx < width; nx++) {
|
|
||||||
if (Math.random() > prob) {
|
|
||||||
// 有一定的概率进行微调
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (Math.random() < 0.2) {
|
|
||||||
// 20% 概率与旁边图块互换位置
|
|
||||||
const [dx, dy] =
|
|
||||||
directions[
|
|
||||||
Math.floor(Math.random() * directions.length)
|
|
||||||
];
|
|
||||||
const px = nx + dx;
|
|
||||||
const py = ny + dy;
|
|
||||||
if (px < 0 || px >= width || py < 0 || py >= height) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
[clone[ny][nx], clone[py][px]] = [
|
|
||||||
clone[py][px],
|
|
||||||
clone[ny][nx]
|
|
||||||
];
|
|
||||||
} else {
|
|
||||||
// 80% 概率替换当前图块
|
|
||||||
clone[ny][nx] = Math.floor(Math.random() * tileType.size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const id2 = `${id}.S${i}`;
|
|
||||||
const sid = `${id}:${id2}`;
|
|
||||||
const simi = compareMap(id, id2, map, clone);
|
|
||||||
|
|
||||||
res.push([
|
|
||||||
sid,
|
|
||||||
{
|
|
||||||
map1: map,
|
|
||||||
map2: clone,
|
|
||||||
size: [width, height],
|
|
||||||
topoSimilarity: simi,
|
|
||||||
visionSimilarity: calculateVisualSimilarity(map, clone)
|
|
||||||
}
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
function generatePair(
|
|
||||||
data: Record<string, MinamoTrainData>,
|
|
||||||
id1: string,
|
|
||||||
id2: string,
|
|
||||||
map1: number[][],
|
|
||||||
map2: number[][],
|
|
||||||
size: [number, number]
|
|
||||||
) {
|
|
||||||
const topoSimilarity = compareMap(id1, id2, map1, map2);
|
|
||||||
const visionSimilarity = calculateVisualSimilarity(map1, map2);
|
|
||||||
const train: MinamoTrainData = {
|
|
||||||
map1,
|
|
||||||
map2,
|
|
||||||
topoSimilarity,
|
|
||||||
visionSimilarity,
|
|
||||||
size: size
|
|
||||||
};
|
|
||||||
data[`${id1}:${id2}`] = train;
|
|
||||||
// 自身与自身对比的训练集,保证模型对相同地图输出 1
|
|
||||||
const self1 = `${id1}:${id1}`;
|
|
||||||
const self2 = `${id2}:${id2}`;
|
|
||||||
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
|
|
||||||
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
|
|
||||||
const selfTrain1: MinamoTrainData = {
|
|
||||||
map1: map1,
|
|
||||||
map2: map1,
|
|
||||||
topoSimilarity: 1,
|
|
||||||
visionSimilarity: 1,
|
|
||||||
size: size
|
|
||||||
};
|
|
||||||
data[`${id1}:${id1}`] = selfTrain1;
|
|
||||||
}
|
|
||||||
if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) {
|
|
||||||
const selfTrain2: MinamoTrainData = {
|
|
||||||
map1: map2,
|
|
||||||
map2: map2,
|
|
||||||
topoSimilarity: 1,
|
|
||||||
visionSimilarity: 1,
|
|
||||||
size: size
|
|
||||||
};
|
|
||||||
data[`${id2}:${id2}`] = selfTrain2;
|
|
||||||
}
|
|
||||||
// 翻转、旋转训练集
|
|
||||||
Object.assign(
|
|
||||||
data,
|
|
||||||
Object.fromEntries(
|
|
||||||
generateTransformData(id1, id2, map1, map2, topoSimilarity)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
// 地图微调训练集
|
|
||||||
Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1)));
|
|
||||||
}
|
|
||||||
|
|
||||||
function generateDataset(
|
|
||||||
floors: Map<string, FloorData>,
|
|
||||||
pairs: number[],
|
|
||||||
floorIds: string[]
|
|
||||||
): Record<string, MinamoTrainData> {
|
|
||||||
const data: Record<string, MinamoTrainData> = {};
|
|
||||||
|
|
||||||
const progress = new SingleBar({}, Presets.shades_classic);
|
|
||||||
|
|
||||||
progress.start(pairs.length, 0);
|
|
||||||
|
|
||||||
pairs.forEach((v, i) => {
|
|
||||||
const num1 = Math.floor(v / floorIds.length);
|
|
||||||
const num2 = v % floorIds.length;
|
|
||||||
const id1 = floorIds[num1];
|
|
||||||
const id2 = floorIds[num2];
|
|
||||||
const map1 = floors.get(id1)?.map;
|
|
||||||
const map2 = floors.get(id2)?.map;
|
|
||||||
if (!map1 || !map2) return;
|
|
||||||
const [w1, h1] = [map1[0].length, map1.length];
|
|
||||||
const [w2, h2] = [map2[0].length, map2.length];
|
|
||||||
if (w1 !== w2 || h1 !== h2) return;
|
|
||||||
generatePair(data, id1, id2, map1, map2, [w1, h1]);
|
|
||||||
progress.update(i + 1);
|
|
||||||
});
|
|
||||||
|
|
||||||
progress.stop();
|
|
||||||
|
|
||||||
return data;
|
|
||||||
}
|
|
||||||
|
|
||||||
function parseAllData(data: Map<string, FloorData>): MinamoDataset {
|
|
||||||
const length = data.size;
|
|
||||||
const totalCount = Math.round((length * (length - 1)) / 2);
|
|
||||||
|
|
||||||
const pairs = choosePair(length, 10000);
|
|
||||||
|
|
||||||
console.log(
|
|
||||||
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairs.length} 个组合`
|
|
||||||
);
|
|
||||||
|
|
||||||
const trainData = generateDataset(data, pairs, [...data.keys()]);
|
|
||||||
|
|
||||||
const dataset: MinamoDataset = {
|
|
||||||
datasetId: Math.floor(Math.random() * 1e12),
|
|
||||||
data: trainData
|
|
||||||
};
|
|
||||||
|
|
||||||
return dataset;
|
|
||||||
}
|
|
||||||
|
|
||||||
function generateAssignedData(
|
|
||||||
data1: Map<string, FloorData>,
|
|
||||||
data2: Map<string, FloorData>,
|
|
||||||
count: [number, number]
|
|
||||||
): MinamoDataset {
|
|
||||||
const length = data1.size + data2.size;
|
|
||||||
const totalCount = data1.size * data2.size;
|
|
||||||
const count1 = Math.min(count[0], data1.size);
|
|
||||||
const count2 = Math.min(count[1], data2.size);
|
|
||||||
const keys1 = [...data1.keys()];
|
|
||||||
const keys2 = [...data2.keys()];
|
|
||||||
const choose1 = chooseFrom(keys1, count1);
|
|
||||||
|
|
||||||
const trainData: Record<string, MinamoTrainData> = {};
|
|
||||||
|
|
||||||
console.log(
|
|
||||||
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${
|
|
||||||
count1 * count2
|
|
||||||
} 个组合`
|
|
||||||
);
|
|
||||||
|
|
||||||
const progress = new SingleBar({}, Presets.shades_classic);
|
|
||||||
progress.start(count1 * count2, 0);
|
|
||||||
let n = 0;
|
|
||||||
|
|
||||||
for (const key1 of choose1) {
|
|
||||||
const choose2 = chooseFrom(keys2, count2);
|
|
||||||
for (const key2 of choose2) {
|
|
||||||
const { map: map1 } = data1.get(key1)!;
|
|
||||||
const { map: map2 } = data2.get(key2)!;
|
|
||||||
if (!map1 || !map2) continue;
|
|
||||||
const [w1, h1] = [map1[0].length, map1.length];
|
|
||||||
const [w2, h2] = [map2[0].length, map2.length];
|
|
||||||
if (w1 !== w2 || h1 !== h2) continue;
|
|
||||||
generatePair(trainData, key1, key2, map1, map2, [w1, h1]);
|
|
||||||
n++;
|
|
||||||
progress.update(n);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
progress.stop();
|
|
||||||
|
|
||||||
const dataset: MinamoDataset = {
|
|
||||||
datasetId: Math.floor(Math.random() * 1e12),
|
|
||||||
data: trainData
|
|
||||||
};
|
|
||||||
|
|
||||||
return dataset;
|
|
||||||
}
|
|
||||||
|
|
||||||
(async () => {
|
(async () => {
|
||||||
if (!assigned) {
|
if (!assigned) {
|
||||||
const towers = await Promise.all(
|
const towers = await Promise.all(
|
||||||
list.map(v => parseTowerInfo(v, 'minamo-config.json'))
|
list.map(v => parseTowerInfo(v, 'minamo-config.json'))
|
||||||
);
|
);
|
||||||
const floors = await getAllFloors(...towers);
|
const floors = await getAllFloors(...towers);
|
||||||
const results = parseAllData(floors);
|
const results = parseMinamo(floors);
|
||||||
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
|
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
|
||||||
const size = Object.keys(results.data).length;
|
const size = Object.keys(results.data).length;
|
||||||
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`);
|
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`);
|
||||||
|
|||||||
30
data/src/process/ginka.ts
Normal file
30
data/src/process/ginka.ts
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import { SingleBar, Presets } from 'cli-progress';
|
||||||
|
import { GinkaTrainData, GinkaConfig, GinkaDataset } from 'src/types';
|
||||||
|
import { FloorData } from 'src/utils';
|
||||||
|
|
||||||
|
export function parseGinka(data: Map<string, FloorData>) {
|
||||||
|
const resolved: Record<string, GinkaTrainData> = {};
|
||||||
|
|
||||||
|
const progress = new SingleBar({}, Presets.shades_classic);
|
||||||
|
progress.start(data.size, 0);
|
||||||
|
let i = 0;
|
||||||
|
|
||||||
|
data.forEach((floor, key) => {
|
||||||
|
const config = floor.config as GinkaConfig;
|
||||||
|
const text = config.data[floor.id] ?? [];
|
||||||
|
resolved[key] = {
|
||||||
|
map: floor.map,
|
||||||
|
size: [floor.map[0].length, floor.map.length],
|
||||||
|
text: text
|
||||||
|
};
|
||||||
|
i++;
|
||||||
|
progress.update(i);
|
||||||
|
});
|
||||||
|
|
||||||
|
const dataset: GinkaDataset = {
|
||||||
|
datasetId: Math.floor(Math.random() * 1e12),
|
||||||
|
data: resolved
|
||||||
|
};
|
||||||
|
|
||||||
|
return dataset;
|
||||||
|
}
|
||||||
395
data/src/process/minamo.ts
Normal file
395
data/src/process/minamo.ts
Normal file
@ -0,0 +1,395 @@
|
|||||||
|
import { SingleBar, Presets } from 'cli-progress';
|
||||||
|
import { compareMap } from 'src/topology/compare';
|
||||||
|
import { directions, tileType } from 'src/topology/graph';
|
||||||
|
import { rotateMap, mirrorMapX, mirrorMapY } from 'src/topology/transform';
|
||||||
|
import { MinamoTrainData, MinamoDataset } from 'src/types';
|
||||||
|
import { chooseFrom, FloorData } from 'src/utils';
|
||||||
|
import { calculateVisualSimilarity } from 'src/vision/similarity';
|
||||||
|
|
||||||
|
function chooseN(maxCount: number, n: number) {
|
||||||
|
return chooseFrom(
|
||||||
|
Array(maxCount)
|
||||||
|
.fill(0)
|
||||||
|
.map((_, i) => i),
|
||||||
|
n
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function choosePair(n: number, max: number = 1000) {
|
||||||
|
const totalCount = Math.round((n * (n - 1)) / 2);
|
||||||
|
const count = Math.min(totalCount, max);
|
||||||
|
const pairs: number[] = [];
|
||||||
|
for (let i = 0; i < n; i++) {
|
||||||
|
for (let j = i + 1; j < n; j++) {
|
||||||
|
pairs.push(i * n + j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 直接打乱后取前 count 个
|
||||||
|
for (let i = pairs.length - 1; i > 0; i--) {
|
||||||
|
let randIndex = Math.floor(Math.random() * (i + 1));
|
||||||
|
[pairs[i], pairs[randIndex]] = [pairs[randIndex], pairs[i]];
|
||||||
|
}
|
||||||
|
|
||||||
|
return pairs.slice(0, count);
|
||||||
|
}
|
||||||
|
|
||||||
|
function transform(map: number[][], rot: number, flip: number) {
|
||||||
|
let res = map;
|
||||||
|
for (let i = 0; i < rot; i++) {
|
||||||
|
res = rotateMap(res);
|
||||||
|
}
|
||||||
|
if (flip & 0b01) {
|
||||||
|
res = mirrorMapX(res);
|
||||||
|
}
|
||||||
|
if (flip & 0b10) {
|
||||||
|
res = mirrorMapY(res);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
function generateTransformData(
|
||||||
|
id1: string,
|
||||||
|
id2: string,
|
||||||
|
map1: number[][],
|
||||||
|
map2: number[][],
|
||||||
|
simi: number
|
||||||
|
) {
|
||||||
|
const types: [rot: number, flip: number][] = [];
|
||||||
|
for (const rot of [0, 1, 2, 3]) {
|
||||||
|
for (const flip of [0b00, 0b01, 0b10, 0b11]) {
|
||||||
|
if (rot === 0 && flip === 0) continue;
|
||||||
|
types.push([rot, flip]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 随机抽取最多一个
|
||||||
|
const trans = chooseFrom(types, Math.floor(Math.random() * 1));
|
||||||
|
return trans
|
||||||
|
.map(([rot, flip]) => {
|
||||||
|
const com1 = `${id1}.${rot}.${flip}:${id1}`;
|
||||||
|
const com2 = `${id1}.${rot}.${flip}:${id2}`;
|
||||||
|
const com3 = `${id2}.${rot}.${flip}:${id1}`;
|
||||||
|
const com4 = `${id2}.${rot}.${flip}:${id2}`;
|
||||||
|
const choose = chooseFrom(
|
||||||
|
[com1, com2, com3, com4],
|
||||||
|
Math.floor(Math.random() * 2)
|
||||||
|
);
|
||||||
|
const res: [id: string, data: MinamoTrainData][] = [];
|
||||||
|
if (choose.includes(com1)) {
|
||||||
|
const t = transform(map1, rot, flip);
|
||||||
|
res.push([
|
||||||
|
com1,
|
||||||
|
{
|
||||||
|
map1: t,
|
||||||
|
map2: map1,
|
||||||
|
topoSimilarity: 1,
|
||||||
|
visionSimilarity: calculateVisualSimilarity(map1, t),
|
||||||
|
size: [map1[0].length, map1.length]
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
if (choose.includes(com2)) {
|
||||||
|
const t = transform(map1, rot, flip);
|
||||||
|
res.push([
|
||||||
|
com2,
|
||||||
|
{
|
||||||
|
map1: t,
|
||||||
|
map2: map2,
|
||||||
|
topoSimilarity: simi,
|
||||||
|
visionSimilarity: calculateVisualSimilarity(t, map2),
|
||||||
|
size: [map1[0].length, map1.length]
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
if (choose.includes(com3)) {
|
||||||
|
const t = transform(map2, rot, flip);
|
||||||
|
res.push([
|
||||||
|
com3,
|
||||||
|
{
|
||||||
|
map1: t,
|
||||||
|
map2: map1,
|
||||||
|
topoSimilarity: simi,
|
||||||
|
visionSimilarity: calculateVisualSimilarity(t, map1),
|
||||||
|
size: [map1[0].length, map1.length]
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
if (choose.includes(com4)) {
|
||||||
|
const t = transform(map2, rot, flip);
|
||||||
|
res.push([
|
||||||
|
com4,
|
||||||
|
{
|
||||||
|
map1: t,
|
||||||
|
map2: map2,
|
||||||
|
topoSimilarity: 1,
|
||||||
|
visionSimilarity: calculateVisualSimilarity(t, map2),
|
||||||
|
size: [map1[0].length, map1.length]
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
})
|
||||||
|
.flat();
|
||||||
|
}
|
||||||
|
|
||||||
|
function generateSimilarData(id: string, map: number[][]) {
|
||||||
|
// 生成最多两个微调地图
|
||||||
|
const width = map[0].length;
|
||||||
|
const height = map.length;
|
||||||
|
const num = Math.floor(Math.random() * 1);
|
||||||
|
const res: [id: string, data: MinamoTrainData][] = [];
|
||||||
|
|
||||||
|
for (let i = 0; i < num; i++) {
|
||||||
|
const clone = map.map(v => v.slice());
|
||||||
|
const prob = Math.random() * 0.3;
|
||||||
|
for (let ny = 0; ny < height; ny++) {
|
||||||
|
for (let nx = 0; nx < width; nx++) {
|
||||||
|
if (Math.random() > prob) {
|
||||||
|
// 有一定的概率进行微调
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (Math.random() < 0.2) {
|
||||||
|
// 20% 概率与旁边图块互换位置
|
||||||
|
const [dx, dy] =
|
||||||
|
directions[
|
||||||
|
Math.floor(Math.random() * directions.length)
|
||||||
|
];
|
||||||
|
const px = nx + dx;
|
||||||
|
const py = ny + dy;
|
||||||
|
if (px < 0 || px >= width || py < 0 || py >= height) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
[clone[ny][nx], clone[py][px]] = [
|
||||||
|
clone[py][px],
|
||||||
|
clone[ny][nx]
|
||||||
|
];
|
||||||
|
} else {
|
||||||
|
// 80% 概率替换当前图块
|
||||||
|
clone[ny][nx] = Math.floor(Math.random() * tileType.size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const id2 = `${id}.S${i}`;
|
||||||
|
const sid = `${id}:${id2}`;
|
||||||
|
const simi = compareMap(id, id2, map, clone);
|
||||||
|
|
||||||
|
res.push([
|
||||||
|
sid,
|
||||||
|
{
|
||||||
|
map1: map,
|
||||||
|
map2: clone,
|
||||||
|
size: [width, height],
|
||||||
|
topoSimilarity: simi,
|
||||||
|
visionSimilarity: calculateVisualSimilarity(map, clone)
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function generateTrainData(
|
||||||
|
id1: string,
|
||||||
|
id2: string,
|
||||||
|
map1: number[][],
|
||||||
|
map2: number[][],
|
||||||
|
size: [number, number]
|
||||||
|
) {
|
||||||
|
const topoSimilarity = compareMap(id1, id2, map1, map2);
|
||||||
|
const visionSimilarity = calculateVisualSimilarity(map1, map2);
|
||||||
|
const train: MinamoTrainData = {
|
||||||
|
map1,
|
||||||
|
map2,
|
||||||
|
topoSimilarity,
|
||||||
|
visionSimilarity,
|
||||||
|
size: size
|
||||||
|
};
|
||||||
|
const data: MinamoTrainData[] = [];
|
||||||
|
data.push(train);
|
||||||
|
// 自身与自身对比的训练集,保证模型对相同地图输出 1
|
||||||
|
const self1 = `${id1}:${id1}`;
|
||||||
|
const self2 = `${id2}:${id2}`;
|
||||||
|
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
|
||||||
|
if (selfTrain.includes(self1)) {
|
||||||
|
const selfTrain1: MinamoTrainData = {
|
||||||
|
map1: map1,
|
||||||
|
map2: map1,
|
||||||
|
topoSimilarity: 1,
|
||||||
|
visionSimilarity: 1,
|
||||||
|
size: size
|
||||||
|
};
|
||||||
|
data.push(selfTrain1);
|
||||||
|
}
|
||||||
|
if (selfTrain.includes(self2)) {
|
||||||
|
const selfTrain2: MinamoTrainData = {
|
||||||
|
map1: map2,
|
||||||
|
map2: map2,
|
||||||
|
topoSimilarity: 1,
|
||||||
|
visionSimilarity: 1,
|
||||||
|
size: size
|
||||||
|
};
|
||||||
|
data.push(selfTrain2);
|
||||||
|
}
|
||||||
|
const transform = generateTransformData(
|
||||||
|
id1,
|
||||||
|
id2,
|
||||||
|
map1,
|
||||||
|
map2,
|
||||||
|
topoSimilarity
|
||||||
|
);
|
||||||
|
const similar = generateSimilarData(id1, map1);
|
||||||
|
return [...data, ...transform.map(v => v[1]), ...similar.map(v => v[1])];
|
||||||
|
}
|
||||||
|
|
||||||
|
export function generatePair(
|
||||||
|
data: Record<string, MinamoTrainData>,
|
||||||
|
id1: string,
|
||||||
|
id2: string,
|
||||||
|
map1: number[][],
|
||||||
|
map2: number[][],
|
||||||
|
size: [number, number]
|
||||||
|
) {
|
||||||
|
const topoSimilarity = compareMap(id1, id2, map1, map2);
|
||||||
|
const visionSimilarity = calculateVisualSimilarity(map1, map2);
|
||||||
|
const train: MinamoTrainData = {
|
||||||
|
map1,
|
||||||
|
map2,
|
||||||
|
topoSimilarity,
|
||||||
|
visionSimilarity,
|
||||||
|
size: size
|
||||||
|
};
|
||||||
|
data[`${id1}:${id2}`] = train;
|
||||||
|
// 自身与自身对比的训练集,保证模型对相同地图输出 1
|
||||||
|
const self1 = `${id1}:${id1}`;
|
||||||
|
const self2 = `${id2}:${id2}`;
|
||||||
|
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
|
||||||
|
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
|
||||||
|
const selfTrain1: MinamoTrainData = {
|
||||||
|
map1: map1,
|
||||||
|
map2: map1,
|
||||||
|
topoSimilarity: 1,
|
||||||
|
visionSimilarity: 1,
|
||||||
|
size: size
|
||||||
|
};
|
||||||
|
data[`${id1}:${id1}`] = selfTrain1;
|
||||||
|
}
|
||||||
|
if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) {
|
||||||
|
const selfTrain2: MinamoTrainData = {
|
||||||
|
map1: map2,
|
||||||
|
map2: map2,
|
||||||
|
topoSimilarity: 1,
|
||||||
|
visionSimilarity: 1,
|
||||||
|
size: size
|
||||||
|
};
|
||||||
|
data[`${id2}:${id2}`] = selfTrain2;
|
||||||
|
}
|
||||||
|
// 翻转、旋转训练集
|
||||||
|
Object.assign(
|
||||||
|
data,
|
||||||
|
Object.fromEntries(
|
||||||
|
generateTransformData(id1, id2, map1, map2, topoSimilarity)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
// 地图微调训练集
|
||||||
|
Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
function generateDataset(
|
||||||
|
floors: Map<string, FloorData>,
|
||||||
|
pairs: number[],
|
||||||
|
floorIds: string[]
|
||||||
|
): Record<string, MinamoTrainData> {
|
||||||
|
const data: Record<string, MinamoTrainData> = {};
|
||||||
|
|
||||||
|
const progress = new SingleBar({}, Presets.shades_classic);
|
||||||
|
|
||||||
|
progress.start(pairs.length, 0);
|
||||||
|
|
||||||
|
pairs.forEach((v, i) => {
|
||||||
|
const num1 = Math.floor(v / floorIds.length);
|
||||||
|
const num2 = v % floorIds.length;
|
||||||
|
const id1 = floorIds[num1];
|
||||||
|
const id2 = floorIds[num2];
|
||||||
|
const map1 = floors.get(id1)?.map;
|
||||||
|
const map2 = floors.get(id2)?.map;
|
||||||
|
if (!map1 || !map2) return;
|
||||||
|
const [w1, h1] = [map1[0].length, map1.length];
|
||||||
|
const [w2, h2] = [map2[0].length, map2.length];
|
||||||
|
if (w1 !== w2 || h1 !== h2) return;
|
||||||
|
generatePair(data, id1, id2, map1, map2, [w1, h1]);
|
||||||
|
progress.update(i + 1);
|
||||||
|
});
|
||||||
|
|
||||||
|
progress.stop();
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function parseMinamo(data: Map<string, FloorData>): MinamoDataset {
|
||||||
|
const length = data.size;
|
||||||
|
const totalCount = Math.round((length * (length - 1)) / 2);
|
||||||
|
|
||||||
|
const pairs = choosePair(length, 10000);
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairs.length} 个组合`
|
||||||
|
);
|
||||||
|
|
||||||
|
const trainData = generateDataset(data, pairs, [...data.keys()]);
|
||||||
|
|
||||||
|
const dataset: MinamoDataset = {
|
||||||
|
datasetId: Math.floor(Math.random() * 1e12),
|
||||||
|
data: trainData
|
||||||
|
};
|
||||||
|
|
||||||
|
return dataset;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function generateAssignedData(
|
||||||
|
data1: Map<string, FloorData>,
|
||||||
|
data2: Map<string, FloorData>,
|
||||||
|
count: [number, number]
|
||||||
|
): MinamoDataset {
|
||||||
|
const length = data1.size + data2.size;
|
||||||
|
const totalCount = data1.size * data2.size;
|
||||||
|
const count1 = Math.min(count[0], data1.size);
|
||||||
|
const count2 = Math.min(count[1], data2.size);
|
||||||
|
const keys1 = [...data1.keys()];
|
||||||
|
const keys2 = [...data2.keys()];
|
||||||
|
const choose1 = chooseFrom(keys1, count1);
|
||||||
|
|
||||||
|
const trainData: Record<string, MinamoTrainData> = {};
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${
|
||||||
|
count1 * count2
|
||||||
|
} 个组合`
|
||||||
|
);
|
||||||
|
|
||||||
|
const progress = new SingleBar({}, Presets.shades_classic);
|
||||||
|
progress.start(count1 * count2, 0);
|
||||||
|
let n = 0;
|
||||||
|
|
||||||
|
for (const key1 of choose1) {
|
||||||
|
const choose2 = chooseFrom(keys2, count2);
|
||||||
|
for (const key2 of choose2) {
|
||||||
|
const { map: map1 } = data1.get(key1)!;
|
||||||
|
const { map: map2 } = data2.get(key2)!;
|
||||||
|
if (!map1 || !map2) continue;
|
||||||
|
const [w1, h1] = [map1[0].length, map1.length];
|
||||||
|
const [w2, h2] = [map2[0].length, map2.length];
|
||||||
|
if (w1 !== w2 || h1 !== h2) continue;
|
||||||
|
generatePair(trainData, key1, key2, map1, map2, [w1, h1]);
|
||||||
|
n++;
|
||||||
|
progress.update(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
progress.stop();
|
||||||
|
|
||||||
|
const dataset: MinamoDataset = {
|
||||||
|
datasetId: Math.floor(Math.random() * 1e12),
|
||||||
|
data: trainData
|
||||||
|
};
|
||||||
|
|
||||||
|
return dataset;
|
||||||
|
}
|
||||||
@ -11,3 +11,33 @@ export interface TowerInfo {
|
|||||||
floorIds: string[];
|
floorIds: string[];
|
||||||
config: BaseConfig;
|
config: BaseConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface GinkaConfig extends BaseConfig {
|
||||||
|
data: Record<string, string[]>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GinkaTrainData {
|
||||||
|
text: string[];
|
||||||
|
map: number[][];
|
||||||
|
size: [number, number];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GinkaDataset {
|
||||||
|
datasetId: number;
|
||||||
|
data: Record<string, GinkaTrainData>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MinamoConfig extends BaseConfig {}
|
||||||
|
|
||||||
|
export interface MinamoTrainData {
|
||||||
|
map1: number[][];
|
||||||
|
map2: number[][];
|
||||||
|
topoSimilarity: number;
|
||||||
|
visionSimilarity: number;
|
||||||
|
size: [number, number];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MinamoDataset {
|
||||||
|
datasetId: number;
|
||||||
|
data: Record<string, MinamoTrainData>;
|
||||||
|
}
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import random
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
@ -16,6 +17,12 @@ def load_data(path: str):
|
|||||||
|
|
||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
|
def load_minamo_gan_data(data: list):
|
||||||
|
res = list()
|
||||||
|
for one in data:
|
||||||
|
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
|
||||||
|
return res
|
||||||
|
|
||||||
class GinkaDataset(Dataset):
|
class GinkaDataset(Dataset):
|
||||||
def __init__(self, data_path: str, device, minamo: MinamoModel):
|
def __init__(self, data_path: str, device, minamo: MinamoModel):
|
||||||
self.data = load_data(data_path) # 自定义数据加载函数
|
self.data = load_data(data_path) # 自定义数据加载函数
|
||||||
@ -40,4 +47,45 @@ class GinkaDataset(Dataset):
|
|||||||
"target_topo_feat": topo_feat,
|
"target_topo_feat": topo_feat,
|
||||||
"target": target,
|
"target": target,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class MinamoGANDataset(Dataset):
|
||||||
|
def __init__(self, refer_data_path):
|
||||||
|
self.refer = load_minamo_gan_data(load_data(refer_data_path))
|
||||||
|
self.data = list().extend(self.refer)
|
||||||
|
|
||||||
|
def set_data(self, data: list):
|
||||||
|
self.data = data.extend(self.refer)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
item = self.data[idx]
|
||||||
|
|
||||||
|
map1, map2, vis_sim, topo_sim, review = item
|
||||||
|
map1 = torch.ShortTensor(map1)
|
||||||
|
map2 = torch.ShortTensor(map2)
|
||||||
|
# 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换
|
||||||
|
if review:
|
||||||
|
map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
|
map2 = F.one_hot(map2, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
|
|
||||||
|
min_main = random.uniform(0.75, 0.9)
|
||||||
|
max_main = random.uniform(0.9, 1)
|
||||||
|
epsilon = random.uniform(0, 0.25)
|
||||||
|
|
||||||
|
if review:
|
||||||
|
map1 = random_smooth_onehot(map1, min_main, max_main, epsilon)
|
||||||
|
map2 = random_smooth_onehot(map2, min_main, max_main, epsilon)
|
||||||
|
|
||||||
|
graph1 = differentiable_convert_to_data(map1)
|
||||||
|
graph2 = differentiable_convert_to_data(map2)
|
||||||
|
|
||||||
|
return (
|
||||||
|
map1,
|
||||||
|
map2,
|
||||||
|
torch.FloatTensor([vis_sim]),
|
||||||
|
torch.FloatTensor([topo_sim]),
|
||||||
|
graph1,
|
||||||
|
graph2
|
||||||
|
)
|
||||||
@ -1,4 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -253,11 +254,8 @@ class GinkaLoss(nn.Module):
|
|||||||
minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim
|
minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim
|
||||||
minamo_loss = (1.0 - minamo_sim).mean()
|
minamo_loss = (1.0 - minamo_sim).mean()
|
||||||
|
|
||||||
print(
|
tqdm.write(
|
||||||
minamo_loss.item(),
|
f"{minamo_loss.item():.8f}, {class_loss.item():.8f}, {entrance_loss.item():.8f}, {count_loss.item():.8f}"
|
||||||
class_loss.item(),
|
|
||||||
entrance_loss.item(),
|
|
||||||
count_loss.item()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
losses = [
|
losses = [
|
||||||
@ -267,7 +265,4 @@ class GinkaLoss(nn.Module):
|
|||||||
count_loss * self.weight[3]
|
count_loss * self.weight[3]
|
||||||
]
|
]
|
||||||
|
|
||||||
# 梯度归一化
|
return sum(losses)
|
||||||
scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
|
|
||||||
total_loss = sum(scaled_losses)
|
|
||||||
return total_loss, sum(losses)
|
|
||||||
|
|||||||
@ -78,7 +78,7 @@ def train():
|
|||||||
_, output_softmax = model(feat_vec)
|
_, output_softmax = model(feat_vec)
|
||||||
|
|
||||||
# 计算损失
|
# 计算损失
|
||||||
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||||
|
|
||||||
# 反向传播
|
# 反向传播
|
||||||
losses.backward()
|
losses.backward()
|
||||||
@ -111,7 +111,7 @@ def train():
|
|||||||
print(torch.argmax(output, dim=1)[0])
|
print(torch.argmax(output, dim=1)[0])
|
||||||
|
|
||||||
# 计算损失
|
# 计算损失
|
||||||
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||||
loss_val += losses.item()
|
loss_val += losses.item()
|
||||||
|
|
||||||
avg_val_loss = loss_val / len(dataloader_val)
|
avg_val_loss = loss_val / len(dataloader_val)
|
||||||
|
|||||||
305
ginka/train_gan.py
Normal file
305
ginka/train_gan.py
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
import argparse
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch_geometric.loader import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from .model.model import GinkaModel
|
||||||
|
from .model.loss import GinkaLoss
|
||||||
|
from .dataset import GinkaDataset, MinamoGANDataset
|
||||||
|
from minamo.model.model import MinamoModel
|
||||||
|
from minamo.model.loss import MinamoLoss
|
||||||
|
from shared.image import matrix_to_image_cv
|
||||||
|
|
||||||
|
BATCH_SIZE = 32
|
||||||
|
EPOCHS_GINKA = 30
|
||||||
|
EPOCHS_MINAMO = 10
|
||||||
|
SOCKET_PATH = "./tmp/ginka_uds"
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
os.makedirs("result", exist_ok=True)
|
||||||
|
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
||||||
|
os.makedirs("tmp", exist_ok=True)
|
||||||
|
|
||||||
|
def parse_arguments():
|
||||||
|
parser = argparse.ArgumentParser(description="training codes")
|
||||||
|
parser.add_argument("--resume", type=bool, default=False)
|
||||||
|
parser.add_argument("--from_state", type=str, default="result/ginka.pth")
|
||||||
|
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||||
|
parser.add_argument("--validate", type=str, default='ginka-eval.json')
|
||||||
|
parser.add_argument("--from_cycle", type=int, default=2)
|
||||||
|
parser.add_argument("--to_cycle", type=int, default=100)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
def parse_ginka_batch(batch):
|
||||||
|
target = batch["target"].to(device)
|
||||||
|
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||||
|
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||||
|
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||||
|
|
||||||
|
return target, target_vision_feat, target_topo_feat, feat_vec
|
||||||
|
|
||||||
|
def parse_minamo_batch(batch):
|
||||||
|
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
|
||||||
|
map1 = map1.to(device) # 转为 [B, C, H, W]
|
||||||
|
map2 = map2.to(device)
|
||||||
|
topo_simi = topo_simi.to(device)
|
||||||
|
vision_simi = vision_simi.to(device)
|
||||||
|
graph1 = graph1.to(device)
|
||||||
|
graph2 = graph2.to(device)
|
||||||
|
return map1, map2, vision_simi, topo_simi, graph1, graph2
|
||||||
|
|
||||||
|
def send_all(sock, data):
|
||||||
|
total_sent = 0
|
||||||
|
while total_sent < len(data):
|
||||||
|
sent = sock.send(data[total_sent:])
|
||||||
|
if sent == 0:
|
||||||
|
raise RuntimeError("Socket connection broken")
|
||||||
|
total_sent += sent
|
||||||
|
|
||||||
|
def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
||||||
|
# 数据通讯 node 输出协议,单位字节:
|
||||||
|
# 2 - Tensor count; 2 - Review count. Review is right behind train data;
|
||||||
|
# 1*tc - Compare count for every map tensor delivered.
|
||||||
|
# 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo;
|
||||||
|
# N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor.
|
||||||
|
_, _, H, W = maps.shape
|
||||||
|
tc_buf = sock.recv(2)
|
||||||
|
rc_buf = sock.recv(2)
|
||||||
|
tc = struct.unpack('>h', tc_buf)[0]
|
||||||
|
rc = struct.unpack('>h', rc_buf)[0]
|
||||||
|
count_buf = sock.recv(1 * tc)
|
||||||
|
count: list = struct.unpack(f">{tc}b", count_buf)[0]
|
||||||
|
N = sum(count)
|
||||||
|
sim_buf = sock.recv(2 * 4 * (N + rc))
|
||||||
|
com_buf = sock.recv(N * 1 * H * W)
|
||||||
|
review_buf = sock.recv(rc * 2 * H * W) if rc > 0 else bytes()
|
||||||
|
|
||||||
|
sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)[0]
|
||||||
|
com = struct.unpack(f">{N * 1 * H * W}b", com_buf)[0]
|
||||||
|
review = struct.unpack(f">{rc * 2 * H * W}", review_buf)[0] if rc > 0 else list()
|
||||||
|
|
||||||
|
res = list()
|
||||||
|
flatten_idx = 0
|
||||||
|
# 读取当前这一轮生成器的数据
|
||||||
|
for idx in range(N):
|
||||||
|
com_count = count[idx]
|
||||||
|
for i in range(com_count):
|
||||||
|
com_start = flatten_idx * H * W
|
||||||
|
com_end = (flatten_idx + 1) * H * W
|
||||||
|
vis_sim = sim[flatten_idx * 2]
|
||||||
|
topo_sim = sim[flatten_idx * 2 + 1]
|
||||||
|
com_data = com[com_start:com_end]
|
||||||
|
flatten_idx += 1
|
||||||
|
com_map = np.fromiter(com_data, np.int8).view(H, W)
|
||||||
|
# map1, map2, vision_similarity, topo_similarity, is_review
|
||||||
|
res.append((maps[idx], com_map, vis_sim, topo_sim, False))
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def train():
|
||||||
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||||
|
|
||||||
|
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
|
||||||
|
|
||||||
|
ginka = GinkaModel()
|
||||||
|
ginka.to(device)
|
||||||
|
minamo = MinamoModel(32)
|
||||||
|
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
||||||
|
minamo.to(device)
|
||||||
|
minamo.eval()
|
||||||
|
|
||||||
|
# 准备数据集
|
||||||
|
ginka_dataset = GinkaDataset(args.train, device, minamo)
|
||||||
|
ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
|
||||||
|
minamo_dataset = MinamoGANDataset()
|
||||||
|
minamo_dataset_val = MinamoGANDataset()
|
||||||
|
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
|
||||||
|
# 设定优化器与调度器
|
||||||
|
optimizer_ginka = optim.AdamW(ginka.parameters(), lr=1e-3)
|
||||||
|
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
|
criterion_ginka = GinkaLoss(minamo)
|
||||||
|
|
||||||
|
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
|
||||||
|
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
|
criterion_minamo = MinamoLoss()
|
||||||
|
|
||||||
|
# 用于生成图片
|
||||||
|
tile_dict = dict()
|
||||||
|
for file in os.listdir('tiles'):
|
||||||
|
name = os.path.splitext(file)[0]
|
||||||
|
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
|
# 与 node 端通讯
|
||||||
|
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
|
server.bind(SOCKET_PATH)
|
||||||
|
server.listen(1)
|
||||||
|
|
||||||
|
if args.resume:
|
||||||
|
data = torch.load(args.from_state, map_location=device)
|
||||||
|
ginka.load_state_dict(data["model_state"], strict=False)
|
||||||
|
if args.load_optim:
|
||||||
|
optimizer_ginka.load_state_dict(data["optimizer_state"])
|
||||||
|
print("Train from loaded state.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
|
||||||
|
criterion_ginka.weight[0] = 0.0
|
||||||
|
|
||||||
|
for cycle in tqdm(range(args.from_cycle, args.to_cycle)):
|
||||||
|
# -------------------- 训练生成器
|
||||||
|
gen_list: np.ndarray = np.empty(np.int8)
|
||||||
|
prob_list: np.ndarray = np.empty(np.float32)
|
||||||
|
for epoch in tqdm(range(args.epochs), desc="Training Ginka Model"):
|
||||||
|
ginka.train()
|
||||||
|
minamo.eval()
|
||||||
|
total_loss = 0
|
||||||
|
|
||||||
|
# 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来
|
||||||
|
if not args.resume and epoch == 10:
|
||||||
|
criterion_ginka.weight[0] = 0.5
|
||||||
|
|
||||||
|
for batch in ginka_dataloader:
|
||||||
|
# 数据迁移到设备
|
||||||
|
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
||||||
|
# 前向传播
|
||||||
|
optimizer_ginka.zero_grad()
|
||||||
|
_, output_softmax = ginka(feat_vec)
|
||||||
|
# 计算损失
|
||||||
|
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||||
|
# 反向传播
|
||||||
|
losses.backward()
|
||||||
|
optimizer_ginka.step()
|
||||||
|
total_loss += losses.item()
|
||||||
|
|
||||||
|
avg_loss = total_loss / len(ginka_dataloader)
|
||||||
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
|
||||||
|
|
||||||
|
# 学习率调整
|
||||||
|
scheduler_ginka.step()
|
||||||
|
|
||||||
|
if (epoch + 1) % 5 == 0:
|
||||||
|
loss_val = 0
|
||||||
|
ginka.eval()
|
||||||
|
idx = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in ginka_dataloader_val:
|
||||||
|
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
||||||
|
output, output_softmax = ginka(feat_vec)
|
||||||
|
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||||
|
loss_val += losses.item()
|
||||||
|
if epoch + 1 == EPOCHS_GINKA:
|
||||||
|
# 最后一次验证的时候顺带生成图片
|
||||||
|
prob = output_softmax.cpu().numpy()
|
||||||
|
np.concatenate((prob_list, prob), axis=1)
|
||||||
|
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
|
||||||
|
gen_list = np.concatenate((gen_list, map_matrix), axis=1)
|
||||||
|
for matrix in map_matrix:
|
||||||
|
image = matrix_to_image_cv(matrix, tile_dict)
|
||||||
|
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
avg_val_loss = loss_val / len(ginka_dataloader_val)
|
||||||
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||||
|
torch.save({
|
||||||
|
"model_state": ginka.state_dict()
|
||||||
|
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
|
||||||
|
|
||||||
|
tqdm.write(f"Cycle {cycle} Ginka train ended.")
|
||||||
|
torch.save({
|
||||||
|
"model_state": ginka.state_dict()
|
||||||
|
}, f"result/ginka.pth")
|
||||||
|
|
||||||
|
# -------------------- 生成 Minamo 的训练数据
|
||||||
|
|
||||||
|
# 数据通讯 python 输出协议,单位字节:
|
||||||
|
# 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
|
||||||
|
N, H, W = gen_list.shape
|
||||||
|
gen_bytes = gen_list.astype(np.int8).tobytes()
|
||||||
|
buf = bytearray()
|
||||||
|
buf.extend(struct.pack('>h', N)) # Tensor count
|
||||||
|
buf.extend(struct.pack('>b', H)) # Map height
|
||||||
|
buf.extend(struct.pack('>b', W)) # Map width
|
||||||
|
buf.extend(gen_bytes) # Map tensor
|
||||||
|
server.sendall(buf)
|
||||||
|
data = parse_minamo_data(server, prob_list)
|
||||||
|
minamo_dataset.set_data(data)
|
||||||
|
|
||||||
|
# -------------------- 训练判别器
|
||||||
|
for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
|
||||||
|
ginka.eval()
|
||||||
|
minamo.train()
|
||||||
|
total_loss = 0
|
||||||
|
|
||||||
|
for batch in minamo_dataloader:
|
||||||
|
map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch)
|
||||||
|
|
||||||
|
if map1.shape[0] == 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
optimizer_minamo.zero_grad()
|
||||||
|
vision_feat1, topo_feat1 = minamo(map1, graph1)
|
||||||
|
vision_feat2, topo_feat2 = minamo(map2, graph2)
|
||||||
|
|
||||||
|
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||||
|
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||||
|
|
||||||
|
# 计算损失
|
||||||
|
loss = criterion_minamo(vision_pred, topo_pred, vision_simi, topo_simi)
|
||||||
|
|
||||||
|
# 反向传播
|
||||||
|
loss.backward()
|
||||||
|
optimizer_minamo.step()
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
ave_loss = total_loss / len(minamo_dataloader)
|
||||||
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
|
||||||
|
|
||||||
|
scheduler_minamo.step()
|
||||||
|
|
||||||
|
# 每十轮推理一次验证集
|
||||||
|
if (epoch + 1) % 5 == 0:
|
||||||
|
minamo.eval()
|
||||||
|
val_loss = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for val_batch in tqdm(minamo_dataloader_val, leave=False):
|
||||||
|
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)
|
||||||
|
|
||||||
|
vision_feat1, topo_feat1 = minamo(map1_val, graph1)
|
||||||
|
vision_feat2, topo_feat2 = minamo(map2_val, graph2)
|
||||||
|
|
||||||
|
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||||
|
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||||
|
|
||||||
|
# 计算损失
|
||||||
|
loss_val = criterion_minamo(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
|
||||||
|
val_loss += loss_val.item()
|
||||||
|
|
||||||
|
avg_val_loss = val_loss / len(minamo_dataloader_val)
|
||||||
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||||
|
torch.save({
|
||||||
|
"model_state": minamo.state_dict()
|
||||||
|
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
|
||||||
|
|
||||||
|
tqdm.write(f"Cycle {cycle} Minamo train ended.")
|
||||||
|
torch.save({
|
||||||
|
"model_state": minamo.state_dict()
|
||||||
|
}, f"result/ginka.pth")
|
||||||
|
|
||||||
|
print("Train ended.")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.set_num_threads(4)
|
||||||
|
train()
|
||||||
@ -10,59 +10,11 @@ from minamo.model.model import MinamoModel
|
|||||||
from .dataset import GinkaDataset
|
from .dataset import GinkaDataset
|
||||||
from .model.loss import GinkaLoss
|
from .model.loss import GinkaLoss
|
||||||
from .model.model import GinkaModel
|
from .model.model import GinkaModel
|
||||||
|
from shared.image import matrix_to_image_cv
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs('result/ginka_img', exist_ok=True)
|
os.makedirs('result/ginka_img', exist_ok=True)
|
||||||
|
|
||||||
def blend_alpha(bg, fg, alpha):
|
|
||||||
""" 使用 alpha 通道混合前景图块和背景图 """
|
|
||||||
for c in range(3): # 只混合 RGB 三个通道
|
|
||||||
bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c]
|
|
||||||
return bg
|
|
||||||
|
|
||||||
def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
|
||||||
"""
|
|
||||||
使用OpenCV加速的版本(适合大尺寸地图)
|
|
||||||
:param map_matrix: [H, W] 的numpy数组
|
|
||||||
:param tile_set: 字典 {tile_id: cv2图像(BGR格式)}
|
|
||||||
:param tile_size: 图块边长(像素)
|
|
||||||
"""
|
|
||||||
H, W = map_matrix.shape # 获取地图尺寸
|
|
||||||
canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景)
|
|
||||||
|
|
||||||
# 遍历地图矩阵
|
|
||||||
for row in range(H):
|
|
||||||
for col in range(W):
|
|
||||||
tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型
|
|
||||||
x, y = col * tile_size, row * tile_size # 计算像素位置
|
|
||||||
|
|
||||||
# 先绘制地面(0)
|
|
||||||
if '0' in tile_set:
|
|
||||||
canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB
|
|
||||||
|
|
||||||
if tile_index == '11':
|
|
||||||
if row == 0:
|
|
||||||
tile_index = '11_1'
|
|
||||||
elif row == W - 1:
|
|
||||||
tile_index = '11_3'
|
|
||||||
elif col == 0:
|
|
||||||
tile_index = '11_2'
|
|
||||||
elif col == H - 1:
|
|
||||||
tile_index = '11_4'
|
|
||||||
|
|
||||||
# 叠加其他透明图块
|
|
||||||
if tile_index in tile_set and tile_index != 0:
|
|
||||||
tile_rgba = tile_set[tile_index]
|
|
||||||
tile_rgb = tile_rgba[:, :, :3] # 提取 RGB
|
|
||||||
alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha
|
|
||||||
|
|
||||||
# 混合当前图块到背景
|
|
||||||
canvas[y:y+tile_size, x:x+tile_size] = blend_alpha(
|
|
||||||
canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha
|
|
||||||
)
|
|
||||||
|
|
||||||
return canvas
|
|
||||||
|
|
||||||
def validate():
|
def validate():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||||
model = GinkaModel()
|
model = GinkaModel()
|
||||||
|
|||||||
50
shared/image.py
Normal file
50
shared/image.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def blend_alpha(bg, fg, alpha):
|
||||||
|
""" 使用 alpha 通道混合前景图块和背景图 """
|
||||||
|
for c in range(3): # 只混合 RGB 三个通道
|
||||||
|
bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c]
|
||||||
|
return bg
|
||||||
|
|
||||||
|
def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||||
|
"""
|
||||||
|
使用OpenCV加速的版本(适合大尺寸地图)
|
||||||
|
:param map_matrix: [H, W] 的numpy数组
|
||||||
|
:param tile_set: 字典 {tile_id: cv2图像(BGR格式)}
|
||||||
|
:param tile_size: 图块边长(像素)
|
||||||
|
"""
|
||||||
|
H, W = map_matrix.shape # 获取地图尺寸
|
||||||
|
canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景)
|
||||||
|
|
||||||
|
# 遍历地图矩阵
|
||||||
|
for row in range(H):
|
||||||
|
for col in range(W):
|
||||||
|
tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型
|
||||||
|
x, y = col * tile_size, row * tile_size # 计算像素位置
|
||||||
|
|
||||||
|
# 先绘制地面(0)
|
||||||
|
if '0' in tile_set:
|
||||||
|
canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB
|
||||||
|
|
||||||
|
if tile_index == '11':
|
||||||
|
if row == 0:
|
||||||
|
tile_index = '11_1'
|
||||||
|
elif row == W - 1:
|
||||||
|
tile_index = '11_3'
|
||||||
|
elif col == 0:
|
||||||
|
tile_index = '11_2'
|
||||||
|
elif col == H - 1:
|
||||||
|
tile_index = '11_4'
|
||||||
|
|
||||||
|
# 叠加其他透明图块
|
||||||
|
if tile_index in tile_set and tile_index != 0:
|
||||||
|
tile_rgba = tile_set[tile_index]
|
||||||
|
tile_rgb = tile_rgba[:, :, :3] # 提取 RGB
|
||||||
|
alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha
|
||||||
|
|
||||||
|
# 混合当前图块到背景
|
||||||
|
canvas[y:y+tile_size, x:x+tile_size] = blend_alpha(
|
||||||
|
canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha
|
||||||
|
)
|
||||||
|
|
||||||
|
return canvas
|
||||||
Loading…
Reference in New Issue
Block a user