mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 15:01:10 +08:00
feat: 使用 IPC 通讯并实现单脚本对抗训练
This commit is contained in:
parent
4721e9a141
commit
8dea79a9f0
@ -8,6 +8,7 @@
|
||||
"minamo": "tsx ./src/minamo.ts",
|
||||
"merge": "tsx ./src/merge.ts",
|
||||
"review": "tsx ./src/review.ts",
|
||||
"gan": "tsx ./src/gan.ts",
|
||||
"test:topo": "tsx ./src/topology/test.ts",
|
||||
"test:vision": "tsx ./src/vision/test.ts"
|
||||
},
|
||||
|
||||
131
data/src/gan.ts
Normal file
131
data/src/gan.ts
Normal file
@ -0,0 +1,131 @@
|
||||
import { createConnection, Socket } from 'net';
|
||||
import { chooseFrom, FloorData, readOne } from './utils';
|
||||
import { MinamoTrainData } from './types';
|
||||
import { generateTrainData } from './process/minamo';
|
||||
|
||||
const SOCKET_FILE = '../tmp/ginka_uds';
|
||||
const [refer] = process.argv.slice(2);
|
||||
|
||||
let id = 0;
|
||||
|
||||
function readMap(count: number, buffer: Buffer, h: number, w: number) {
|
||||
const area = w * h;
|
||||
|
||||
const maps: number[][][] = Array.from<number[][]>({
|
||||
length: count
|
||||
}).map(() => {
|
||||
return Array.from<number[]>({ length: h }).map(() => {
|
||||
return Array.from<number>({ length: w }).fill(0);
|
||||
});
|
||||
});
|
||||
|
||||
buffer.subarray(4).forEach((v, i) => {
|
||||
const n = Math.floor(i / area);
|
||||
const y = Math.floor((i % area) / w);
|
||||
const x = i % w;
|
||||
maps[n][y][x] = v;
|
||||
});
|
||||
|
||||
return maps;
|
||||
}
|
||||
|
||||
function generateGANData(
|
||||
keys: string[],
|
||||
refer: Map<string, FloorData>,
|
||||
map: number[][]
|
||||
) {
|
||||
const id2 = `$${id++}`;
|
||||
const toTrain = chooseFrom(keys, 4);
|
||||
const data = toTrain.map<MinamoTrainData[]>(v => {
|
||||
const floor = refer.get(v);
|
||||
if (!floor) return [];
|
||||
const size1: [number, number] = [floor.map[0].length, floor.map.length];
|
||||
const size2: [number, number] = [map[0].length, map.length];
|
||||
if (size1[0] !== size2[0] || size1[1] !== size2[1]) return [];
|
||||
|
||||
return generateTrainData(v, id2, floor.map, map, size1);
|
||||
});
|
||||
return data.flat();
|
||||
}
|
||||
|
||||
(async () => {
|
||||
const referTower = await readOne(refer);
|
||||
const keys = [...referTower.keys()];
|
||||
|
||||
const client = createConnection(SOCKET_FILE, () => {
|
||||
console.log(`UDS IPC connected successfully.`);
|
||||
// 发送四字节数据表示连接成功
|
||||
client.write(new Uint8Array([0x00, 0x00, 0x00, 0x00]));
|
||||
});
|
||||
|
||||
client.on('data', buffer => {
|
||||
// 暂时不考虑流式传输,如果后续数据量非常大,再考虑优化
|
||||
// 数据通讯 node 输入协议,单位字节:
|
||||
// 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
|
||||
const count = buffer.readInt16BE();
|
||||
if (buffer.length - 4 !== count * 32 * 32) {
|
||||
client.write(`ERROR: byte length not match.`);
|
||||
return [];
|
||||
}
|
||||
const h = buffer.readInt8(2);
|
||||
const w = buffer.readInt8(3);
|
||||
const map = readMap(count, buffer, h, w);
|
||||
const simData = map.map(v => generateGANData(keys, referTower, v));
|
||||
const rc = 0;
|
||||
const compareData = simData.flat();
|
||||
const reviewData: MinamoTrainData[] = [];
|
||||
|
||||
// 数据通讯 node 输出协议,单位字节:
|
||||
// 2 - Tensor count; 2 - Review count. Review is right behind train data;
|
||||
// 1*tc - Compare count for every map tensor delivered.
|
||||
// 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo;
|
||||
// N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor.
|
||||
const toSend = Buffer.alloc(
|
||||
2 + // Tensor count
|
||||
2 + // Review count
|
||||
count + // Compare count
|
||||
2 * (count + rc) + // Similarity data
|
||||
compareData.length * 1 * h * w + // Compare map
|
||||
rc * 2 * h * w, // Review map
|
||||
0
|
||||
);
|
||||
let offset = 0;
|
||||
toSend.writeInt16BE(count); // Tensor count
|
||||
toSend.writeInt16BE(0, 2); // Review count
|
||||
offset += 2 + 2;
|
||||
// Compare count
|
||||
toSend.set(
|
||||
simData.map(v => v.length),
|
||||
offset
|
||||
);
|
||||
offset += count;
|
||||
// Similarity data
|
||||
compareData.forEach(v => {
|
||||
toSend.writeFloatBE(v.visionSimilarity, offset);
|
||||
offset += 4;
|
||||
toSend.writeFloatBE(v.topoSimilarity, offset);
|
||||
offset += 4;
|
||||
});
|
||||
// Compare map
|
||||
toSend.set(
|
||||
compareData.map(v => v.map1).flat(2),
|
||||
offset // Set from Compare map
|
||||
);
|
||||
offset += compareData.length * 1 * h * w;
|
||||
// Review map
|
||||
toSend.set(
|
||||
reviewData.map(v => [v.map1, v.map2]).flat(3),
|
||||
offset // Set from last chunk
|
||||
);
|
||||
|
||||
client.write(toSend);
|
||||
});
|
||||
|
||||
client.on('end', () => {
|
||||
console.log(`Connection lose.`);
|
||||
});
|
||||
|
||||
client.on('error', () => {
|
||||
client.end();
|
||||
});
|
||||
})();
|
||||
@ -1,61 +1,15 @@
|
||||
import { writeFile } from 'fs-extra';
|
||||
import { FloorData, getAllFloors, parseTowerInfo } from './utils';
|
||||
import { Presets, SingleBar } from 'cli-progress';
|
||||
|
||||
interface GinkaConfig {
|
||||
clip: {
|
||||
defaults: [number, number, number, number];
|
||||
special: Record<string, [number, number, number, number]>;
|
||||
};
|
||||
data: Record<string, string[]>;
|
||||
}
|
||||
|
||||
interface GinkaTrainData {
|
||||
text: string[];
|
||||
map: number[][];
|
||||
size: [number, number];
|
||||
}
|
||||
|
||||
interface GinkaDataset {
|
||||
datasetId: number;
|
||||
data: Record<string, GinkaTrainData>;
|
||||
}
|
||||
import { getAllFloors, parseTowerInfo } from './utils';
|
||||
import { parseGinka } from './process/ginka';
|
||||
|
||||
const [output, ...list] = process.argv.slice(2);
|
||||
|
||||
function parseAllData(data: Map<string, FloorData>) {
|
||||
const resolved: Record<string, GinkaTrainData> = {};
|
||||
|
||||
const progress = new SingleBar({}, Presets.shades_classic);
|
||||
progress.start(data.size, 0);
|
||||
let i = 0;
|
||||
|
||||
data.forEach((floor, key) => {
|
||||
const config = floor.config as GinkaConfig;
|
||||
const text = config.data[floor.id] ?? [];
|
||||
resolved[key] = {
|
||||
map: floor.map,
|
||||
size: [floor.map[0].length, floor.map.length],
|
||||
text: text
|
||||
};
|
||||
i++;
|
||||
progress.update(i);
|
||||
});
|
||||
|
||||
const dataset: GinkaDataset = {
|
||||
datasetId: Math.floor(Math.random() * 1e12),
|
||||
data: resolved
|
||||
};
|
||||
|
||||
return dataset;
|
||||
}
|
||||
|
||||
(async () => {
|
||||
const towers = await Promise.all(
|
||||
list.map(v => parseTowerInfo(v, 'ginka-config.json'))
|
||||
);
|
||||
const floors = await getAllFloors(...towers);
|
||||
const results = parseAllData(floors);
|
||||
const results = parseGinka(floors);
|
||||
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
|
||||
const size = Object.keys(results.data).length;
|
||||
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个地图`);
|
||||
|
||||
@ -1,32 +1,6 @@
|
||||
import { writeFile } from 'fs-extra';
|
||||
import {
|
||||
FloorData,
|
||||
readOne,
|
||||
getAllFloors,
|
||||
parseTowerInfo,
|
||||
chooseFrom
|
||||
} from './utils';
|
||||
import { compareMap } from './topology/compare';
|
||||
import { mirrorMapX, mirrorMapY, rotateMap } from './topology/transform';
|
||||
import { directions, tileType } from './topology/graph';
|
||||
import { calculateVisualSimilarity } from './vision/similarity';
|
||||
import { BaseConfig } from './types';
|
||||
import { Presets, SingleBar } from 'cli-progress';
|
||||
|
||||
interface MinamoConfig extends BaseConfig {}
|
||||
|
||||
interface MinamoTrainData {
|
||||
map1: number[][];
|
||||
map2: number[][];
|
||||
topoSimilarity: number;
|
||||
visionSimilarity: number;
|
||||
size: [number, number];
|
||||
}
|
||||
|
||||
interface MinamoDataset {
|
||||
datasetId: number;
|
||||
data: Record<string, MinamoTrainData>;
|
||||
}
|
||||
import { readOne, getAllFloors, parseTowerInfo } from './utils';
|
||||
import { generateAssignedData, parseMinamo } from './process/minamo';
|
||||
|
||||
const [output, ...list] = process.argv.slice(2);
|
||||
// 判断 assigned 模式,此模式下只会对前两个塔处理,会在这两个塔之间对比,而单个塔的地图不会对比
|
||||
@ -40,348 +14,13 @@ function parseAssigned(arg: string): [number, number] {
|
||||
return [parseInt(a) || 100, parseInt(b) || 100];
|
||||
}
|
||||
|
||||
function chooseN(maxCount: number, n: number) {
|
||||
return chooseFrom(
|
||||
Array(maxCount)
|
||||
.fill(0)
|
||||
.map((_, i) => i),
|
||||
n
|
||||
);
|
||||
}
|
||||
|
||||
function choosePair(n: number, max: number = 1000) {
|
||||
const totalCount = Math.round((n * (n - 1)) / 2);
|
||||
const count = Math.min(totalCount, max);
|
||||
const pairs: number[] = [];
|
||||
for (let i = 0; i < n; i++) {
|
||||
for (let j = i + 1; j < n; j++) {
|
||||
pairs.push(i * n + j);
|
||||
}
|
||||
}
|
||||
// 直接打乱后取前 count 个
|
||||
for (let i = pairs.length - 1; i > 0; i--) {
|
||||
let randIndex = Math.floor(Math.random() * (i + 1));
|
||||
[pairs[i], pairs[randIndex]] = [pairs[randIndex], pairs[i]];
|
||||
}
|
||||
|
||||
return pairs.slice(0, count);
|
||||
}
|
||||
|
||||
function transform(map: number[][], rot: number, flip: number) {
|
||||
let res = map;
|
||||
for (let i = 0; i < rot; i++) {
|
||||
res = rotateMap(res);
|
||||
}
|
||||
if (flip & 0b01) {
|
||||
res = mirrorMapX(res);
|
||||
}
|
||||
if (flip & 0b10) {
|
||||
res = mirrorMapY(res);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function generateTransformData(
|
||||
id1: string,
|
||||
id2: string,
|
||||
map1: number[][],
|
||||
map2: number[][],
|
||||
simi: number
|
||||
) {
|
||||
const types: [rot: number, flip: number][] = [];
|
||||
for (const rot of [0, 1, 2, 3]) {
|
||||
for (const flip of [0b00, 0b01, 0b10, 0b11]) {
|
||||
if (rot === 0 && flip === 0) continue;
|
||||
types.push([rot, flip]);
|
||||
}
|
||||
}
|
||||
// 随机抽取最多一个
|
||||
const trans = chooseFrom(types, Math.floor(Math.random() * 1));
|
||||
return trans
|
||||
.map(([rot, flip]) => {
|
||||
const com1 = `${id1}.${rot}.${flip}:${id1}`;
|
||||
const com2 = `${id1}.${rot}.${flip}:${id2}`;
|
||||
const com3 = `${id2}.${rot}.${flip}:${id1}`;
|
||||
const com4 = `${id2}.${rot}.${flip}:${id2}`;
|
||||
const choose = chooseFrom(
|
||||
[com1, com2, com3, com4],
|
||||
Math.floor(Math.random() * 2)
|
||||
);
|
||||
const res: [id: string, data: MinamoTrainData][] = [];
|
||||
if (choose.includes(com1)) {
|
||||
const t = transform(map1, rot, flip);
|
||||
res.push([
|
||||
com1,
|
||||
{
|
||||
map1: t,
|
||||
map2: map1,
|
||||
topoSimilarity: 1,
|
||||
visionSimilarity: calculateVisualSimilarity(map1, t),
|
||||
size: [map1[0].length, map1.length]
|
||||
}
|
||||
]);
|
||||
}
|
||||
if (choose.includes(com2)) {
|
||||
const t = transform(map1, rot, flip);
|
||||
res.push([
|
||||
com2,
|
||||
{
|
||||
map1: t,
|
||||
map2: map2,
|
||||
topoSimilarity: simi,
|
||||
visionSimilarity: calculateVisualSimilarity(t, map2),
|
||||
size: [map1[0].length, map1.length]
|
||||
}
|
||||
]);
|
||||
}
|
||||
if (choose.includes(com3)) {
|
||||
const t = transform(map2, rot, flip);
|
||||
res.push([
|
||||
com3,
|
||||
{
|
||||
map1: t,
|
||||
map2: map1,
|
||||
topoSimilarity: simi,
|
||||
visionSimilarity: calculateVisualSimilarity(t, map1),
|
||||
size: [map1[0].length, map1.length]
|
||||
}
|
||||
]);
|
||||
}
|
||||
if (choose.includes(com4)) {
|
||||
const t = transform(map2, rot, flip);
|
||||
res.push([
|
||||
com4,
|
||||
{
|
||||
map1: t,
|
||||
map2: map2,
|
||||
topoSimilarity: 1,
|
||||
visionSimilarity: calculateVisualSimilarity(t, map2),
|
||||
size: [map1[0].length, map1.length]
|
||||
}
|
||||
]);
|
||||
}
|
||||
|
||||
return res;
|
||||
})
|
||||
.flat();
|
||||
}
|
||||
|
||||
function generateSimilarData(id: string, map: number[][]) {
|
||||
// 生成最多两个微调地图
|
||||
const width = map[0].length;
|
||||
const height = map.length;
|
||||
const num = Math.floor(Math.random() * 2);
|
||||
const res: [id: string, data: MinamoTrainData][] = [];
|
||||
|
||||
for (let i = 0; i < num; i++) {
|
||||
const clone = map.map(v => v.slice());
|
||||
const prob = Math.random() * 0.3;
|
||||
for (let ny = 0; ny < height; ny++) {
|
||||
for (let nx = 0; nx < width; nx++) {
|
||||
if (Math.random() > prob) {
|
||||
// 有一定的概率进行微调
|
||||
continue;
|
||||
}
|
||||
if (Math.random() < 0.2) {
|
||||
// 20% 概率与旁边图块互换位置
|
||||
const [dx, dy] =
|
||||
directions[
|
||||
Math.floor(Math.random() * directions.length)
|
||||
];
|
||||
const px = nx + dx;
|
||||
const py = ny + dy;
|
||||
if (px < 0 || px >= width || py < 0 || py >= height) {
|
||||
continue;
|
||||
}
|
||||
[clone[ny][nx], clone[py][px]] = [
|
||||
clone[py][px],
|
||||
clone[ny][nx]
|
||||
];
|
||||
} else {
|
||||
// 80% 概率替换当前图块
|
||||
clone[ny][nx] = Math.floor(Math.random() * tileType.size);
|
||||
}
|
||||
}
|
||||
}
|
||||
const id2 = `${id}.S${i}`;
|
||||
const sid = `${id}:${id2}`;
|
||||
const simi = compareMap(id, id2, map, clone);
|
||||
|
||||
res.push([
|
||||
sid,
|
||||
{
|
||||
map1: map,
|
||||
map2: clone,
|
||||
size: [width, height],
|
||||
topoSimilarity: simi,
|
||||
visionSimilarity: calculateVisualSimilarity(map, clone)
|
||||
}
|
||||
]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function generatePair(
|
||||
data: Record<string, MinamoTrainData>,
|
||||
id1: string,
|
||||
id2: string,
|
||||
map1: number[][],
|
||||
map2: number[][],
|
||||
size: [number, number]
|
||||
) {
|
||||
const topoSimilarity = compareMap(id1, id2, map1, map2);
|
||||
const visionSimilarity = calculateVisualSimilarity(map1, map2);
|
||||
const train: MinamoTrainData = {
|
||||
map1,
|
||||
map2,
|
||||
topoSimilarity,
|
||||
visionSimilarity,
|
||||
size: size
|
||||
};
|
||||
data[`${id1}:${id2}`] = train;
|
||||
// 自身与自身对比的训练集,保证模型对相同地图输出 1
|
||||
const self1 = `${id1}:${id1}`;
|
||||
const self2 = `${id2}:${id2}`;
|
||||
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
|
||||
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
|
||||
const selfTrain1: MinamoTrainData = {
|
||||
map1: map1,
|
||||
map2: map1,
|
||||
topoSimilarity: 1,
|
||||
visionSimilarity: 1,
|
||||
size: size
|
||||
};
|
||||
data[`${id1}:${id1}`] = selfTrain1;
|
||||
}
|
||||
if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) {
|
||||
const selfTrain2: MinamoTrainData = {
|
||||
map1: map2,
|
||||
map2: map2,
|
||||
topoSimilarity: 1,
|
||||
visionSimilarity: 1,
|
||||
size: size
|
||||
};
|
||||
data[`${id2}:${id2}`] = selfTrain2;
|
||||
}
|
||||
// 翻转、旋转训练集
|
||||
Object.assign(
|
||||
data,
|
||||
Object.fromEntries(
|
||||
generateTransformData(id1, id2, map1, map2, topoSimilarity)
|
||||
)
|
||||
);
|
||||
// 地图微调训练集
|
||||
Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1)));
|
||||
}
|
||||
|
||||
function generateDataset(
|
||||
floors: Map<string, FloorData>,
|
||||
pairs: number[],
|
||||
floorIds: string[]
|
||||
): Record<string, MinamoTrainData> {
|
||||
const data: Record<string, MinamoTrainData> = {};
|
||||
|
||||
const progress = new SingleBar({}, Presets.shades_classic);
|
||||
|
||||
progress.start(pairs.length, 0);
|
||||
|
||||
pairs.forEach((v, i) => {
|
||||
const num1 = Math.floor(v / floorIds.length);
|
||||
const num2 = v % floorIds.length;
|
||||
const id1 = floorIds[num1];
|
||||
const id2 = floorIds[num2];
|
||||
const map1 = floors.get(id1)?.map;
|
||||
const map2 = floors.get(id2)?.map;
|
||||
if (!map1 || !map2) return;
|
||||
const [w1, h1] = [map1[0].length, map1.length];
|
||||
const [w2, h2] = [map2[0].length, map2.length];
|
||||
if (w1 !== w2 || h1 !== h2) return;
|
||||
generatePair(data, id1, id2, map1, map2, [w1, h1]);
|
||||
progress.update(i + 1);
|
||||
});
|
||||
|
||||
progress.stop();
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
function parseAllData(data: Map<string, FloorData>): MinamoDataset {
|
||||
const length = data.size;
|
||||
const totalCount = Math.round((length * (length - 1)) / 2);
|
||||
|
||||
const pairs = choosePair(length, 10000);
|
||||
|
||||
console.log(
|
||||
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairs.length} 个组合`
|
||||
);
|
||||
|
||||
const trainData = generateDataset(data, pairs, [...data.keys()]);
|
||||
|
||||
const dataset: MinamoDataset = {
|
||||
datasetId: Math.floor(Math.random() * 1e12),
|
||||
data: trainData
|
||||
};
|
||||
|
||||
return dataset;
|
||||
}
|
||||
|
||||
function generateAssignedData(
|
||||
data1: Map<string, FloorData>,
|
||||
data2: Map<string, FloorData>,
|
||||
count: [number, number]
|
||||
): MinamoDataset {
|
||||
const length = data1.size + data2.size;
|
||||
const totalCount = data1.size * data2.size;
|
||||
const count1 = Math.min(count[0], data1.size);
|
||||
const count2 = Math.min(count[1], data2.size);
|
||||
const keys1 = [...data1.keys()];
|
||||
const keys2 = [...data2.keys()];
|
||||
const choose1 = chooseFrom(keys1, count1);
|
||||
|
||||
const trainData: Record<string, MinamoTrainData> = {};
|
||||
|
||||
console.log(
|
||||
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${
|
||||
count1 * count2
|
||||
} 个组合`
|
||||
);
|
||||
|
||||
const progress = new SingleBar({}, Presets.shades_classic);
|
||||
progress.start(count1 * count2, 0);
|
||||
let n = 0;
|
||||
|
||||
for (const key1 of choose1) {
|
||||
const choose2 = chooseFrom(keys2, count2);
|
||||
for (const key2 of choose2) {
|
||||
const { map: map1 } = data1.get(key1)!;
|
||||
const { map: map2 } = data2.get(key2)!;
|
||||
if (!map1 || !map2) continue;
|
||||
const [w1, h1] = [map1[0].length, map1.length];
|
||||
const [w2, h2] = [map2[0].length, map2.length];
|
||||
if (w1 !== w2 || h1 !== h2) continue;
|
||||
generatePair(trainData, key1, key2, map1, map2, [w1, h1]);
|
||||
n++;
|
||||
progress.update(n);
|
||||
}
|
||||
}
|
||||
|
||||
progress.stop();
|
||||
|
||||
const dataset: MinamoDataset = {
|
||||
datasetId: Math.floor(Math.random() * 1e12),
|
||||
data: trainData
|
||||
};
|
||||
|
||||
return dataset;
|
||||
}
|
||||
|
||||
(async () => {
|
||||
if (!assigned) {
|
||||
const towers = await Promise.all(
|
||||
list.map(v => parseTowerInfo(v, 'minamo-config.json'))
|
||||
);
|
||||
const floors = await getAllFloors(...towers);
|
||||
const results = parseAllData(floors);
|
||||
const results = parseMinamo(floors);
|
||||
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
|
||||
const size = Object.keys(results.data).length;
|
||||
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`);
|
||||
|
||||
30
data/src/process/ginka.ts
Normal file
30
data/src/process/ginka.ts
Normal file
@ -0,0 +1,30 @@
|
||||
import { SingleBar, Presets } from 'cli-progress';
|
||||
import { GinkaTrainData, GinkaConfig, GinkaDataset } from 'src/types';
|
||||
import { FloorData } from 'src/utils';
|
||||
|
||||
export function parseGinka(data: Map<string, FloorData>) {
|
||||
const resolved: Record<string, GinkaTrainData> = {};
|
||||
|
||||
const progress = new SingleBar({}, Presets.shades_classic);
|
||||
progress.start(data.size, 0);
|
||||
let i = 0;
|
||||
|
||||
data.forEach((floor, key) => {
|
||||
const config = floor.config as GinkaConfig;
|
||||
const text = config.data[floor.id] ?? [];
|
||||
resolved[key] = {
|
||||
map: floor.map,
|
||||
size: [floor.map[0].length, floor.map.length],
|
||||
text: text
|
||||
};
|
||||
i++;
|
||||
progress.update(i);
|
||||
});
|
||||
|
||||
const dataset: GinkaDataset = {
|
||||
datasetId: Math.floor(Math.random() * 1e12),
|
||||
data: resolved
|
||||
};
|
||||
|
||||
return dataset;
|
||||
}
|
||||
395
data/src/process/minamo.ts
Normal file
395
data/src/process/minamo.ts
Normal file
@ -0,0 +1,395 @@
|
||||
import { SingleBar, Presets } from 'cli-progress';
|
||||
import { compareMap } from 'src/topology/compare';
|
||||
import { directions, tileType } from 'src/topology/graph';
|
||||
import { rotateMap, mirrorMapX, mirrorMapY } from 'src/topology/transform';
|
||||
import { MinamoTrainData, MinamoDataset } from 'src/types';
|
||||
import { chooseFrom, FloorData } from 'src/utils';
|
||||
import { calculateVisualSimilarity } from 'src/vision/similarity';
|
||||
|
||||
function chooseN(maxCount: number, n: number) {
|
||||
return chooseFrom(
|
||||
Array(maxCount)
|
||||
.fill(0)
|
||||
.map((_, i) => i),
|
||||
n
|
||||
);
|
||||
}
|
||||
|
||||
function choosePair(n: number, max: number = 1000) {
|
||||
const totalCount = Math.round((n * (n - 1)) / 2);
|
||||
const count = Math.min(totalCount, max);
|
||||
const pairs: number[] = [];
|
||||
for (let i = 0; i < n; i++) {
|
||||
for (let j = i + 1; j < n; j++) {
|
||||
pairs.push(i * n + j);
|
||||
}
|
||||
}
|
||||
// 直接打乱后取前 count 个
|
||||
for (let i = pairs.length - 1; i > 0; i--) {
|
||||
let randIndex = Math.floor(Math.random() * (i + 1));
|
||||
[pairs[i], pairs[randIndex]] = [pairs[randIndex], pairs[i]];
|
||||
}
|
||||
|
||||
return pairs.slice(0, count);
|
||||
}
|
||||
|
||||
function transform(map: number[][], rot: number, flip: number) {
|
||||
let res = map;
|
||||
for (let i = 0; i < rot; i++) {
|
||||
res = rotateMap(res);
|
||||
}
|
||||
if (flip & 0b01) {
|
||||
res = mirrorMapX(res);
|
||||
}
|
||||
if (flip & 0b10) {
|
||||
res = mirrorMapY(res);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
function generateTransformData(
|
||||
id1: string,
|
||||
id2: string,
|
||||
map1: number[][],
|
||||
map2: number[][],
|
||||
simi: number
|
||||
) {
|
||||
const types: [rot: number, flip: number][] = [];
|
||||
for (const rot of [0, 1, 2, 3]) {
|
||||
for (const flip of [0b00, 0b01, 0b10, 0b11]) {
|
||||
if (rot === 0 && flip === 0) continue;
|
||||
types.push([rot, flip]);
|
||||
}
|
||||
}
|
||||
// 随机抽取最多一个
|
||||
const trans = chooseFrom(types, Math.floor(Math.random() * 1));
|
||||
return trans
|
||||
.map(([rot, flip]) => {
|
||||
const com1 = `${id1}.${rot}.${flip}:${id1}`;
|
||||
const com2 = `${id1}.${rot}.${flip}:${id2}`;
|
||||
const com3 = `${id2}.${rot}.${flip}:${id1}`;
|
||||
const com4 = `${id2}.${rot}.${flip}:${id2}`;
|
||||
const choose = chooseFrom(
|
||||
[com1, com2, com3, com4],
|
||||
Math.floor(Math.random() * 2)
|
||||
);
|
||||
const res: [id: string, data: MinamoTrainData][] = [];
|
||||
if (choose.includes(com1)) {
|
||||
const t = transform(map1, rot, flip);
|
||||
res.push([
|
||||
com1,
|
||||
{
|
||||
map1: t,
|
||||
map2: map1,
|
||||
topoSimilarity: 1,
|
||||
visionSimilarity: calculateVisualSimilarity(map1, t),
|
||||
size: [map1[0].length, map1.length]
|
||||
}
|
||||
]);
|
||||
}
|
||||
if (choose.includes(com2)) {
|
||||
const t = transform(map1, rot, flip);
|
||||
res.push([
|
||||
com2,
|
||||
{
|
||||
map1: t,
|
||||
map2: map2,
|
||||
topoSimilarity: simi,
|
||||
visionSimilarity: calculateVisualSimilarity(t, map2),
|
||||
size: [map1[0].length, map1.length]
|
||||
}
|
||||
]);
|
||||
}
|
||||
if (choose.includes(com3)) {
|
||||
const t = transform(map2, rot, flip);
|
||||
res.push([
|
||||
com3,
|
||||
{
|
||||
map1: t,
|
||||
map2: map1,
|
||||
topoSimilarity: simi,
|
||||
visionSimilarity: calculateVisualSimilarity(t, map1),
|
||||
size: [map1[0].length, map1.length]
|
||||
}
|
||||
]);
|
||||
}
|
||||
if (choose.includes(com4)) {
|
||||
const t = transform(map2, rot, flip);
|
||||
res.push([
|
||||
com4,
|
||||
{
|
||||
map1: t,
|
||||
map2: map2,
|
||||
topoSimilarity: 1,
|
||||
visionSimilarity: calculateVisualSimilarity(t, map2),
|
||||
size: [map1[0].length, map1.length]
|
||||
}
|
||||
]);
|
||||
}
|
||||
|
||||
return res;
|
||||
})
|
||||
.flat();
|
||||
}
|
||||
|
||||
function generateSimilarData(id: string, map: number[][]) {
|
||||
// 生成最多两个微调地图
|
||||
const width = map[0].length;
|
||||
const height = map.length;
|
||||
const num = Math.floor(Math.random() * 1);
|
||||
const res: [id: string, data: MinamoTrainData][] = [];
|
||||
|
||||
for (let i = 0; i < num; i++) {
|
||||
const clone = map.map(v => v.slice());
|
||||
const prob = Math.random() * 0.3;
|
||||
for (let ny = 0; ny < height; ny++) {
|
||||
for (let nx = 0; nx < width; nx++) {
|
||||
if (Math.random() > prob) {
|
||||
// 有一定的概率进行微调
|
||||
continue;
|
||||
}
|
||||
if (Math.random() < 0.2) {
|
||||
// 20% 概率与旁边图块互换位置
|
||||
const [dx, dy] =
|
||||
directions[
|
||||
Math.floor(Math.random() * directions.length)
|
||||
];
|
||||
const px = nx + dx;
|
||||
const py = ny + dy;
|
||||
if (px < 0 || px >= width || py < 0 || py >= height) {
|
||||
continue;
|
||||
}
|
||||
[clone[ny][nx], clone[py][px]] = [
|
||||
clone[py][px],
|
||||
clone[ny][nx]
|
||||
];
|
||||
} else {
|
||||
// 80% 概率替换当前图块
|
||||
clone[ny][nx] = Math.floor(Math.random() * tileType.size);
|
||||
}
|
||||
}
|
||||
}
|
||||
const id2 = `${id}.S${i}`;
|
||||
const sid = `${id}:${id2}`;
|
||||
const simi = compareMap(id, id2, map, clone);
|
||||
|
||||
res.push([
|
||||
sid,
|
||||
{
|
||||
map1: map,
|
||||
map2: clone,
|
||||
size: [width, height],
|
||||
topoSimilarity: simi,
|
||||
visionSimilarity: calculateVisualSimilarity(map, clone)
|
||||
}
|
||||
]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
export function generateTrainData(
|
||||
id1: string,
|
||||
id2: string,
|
||||
map1: number[][],
|
||||
map2: number[][],
|
||||
size: [number, number]
|
||||
) {
|
||||
const topoSimilarity = compareMap(id1, id2, map1, map2);
|
||||
const visionSimilarity = calculateVisualSimilarity(map1, map2);
|
||||
const train: MinamoTrainData = {
|
||||
map1,
|
||||
map2,
|
||||
topoSimilarity,
|
||||
visionSimilarity,
|
||||
size: size
|
||||
};
|
||||
const data: MinamoTrainData[] = [];
|
||||
data.push(train);
|
||||
// 自身与自身对比的训练集,保证模型对相同地图输出 1
|
||||
const self1 = `${id1}:${id1}`;
|
||||
const self2 = `${id2}:${id2}`;
|
||||
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
|
||||
if (selfTrain.includes(self1)) {
|
||||
const selfTrain1: MinamoTrainData = {
|
||||
map1: map1,
|
||||
map2: map1,
|
||||
topoSimilarity: 1,
|
||||
visionSimilarity: 1,
|
||||
size: size
|
||||
};
|
||||
data.push(selfTrain1);
|
||||
}
|
||||
if (selfTrain.includes(self2)) {
|
||||
const selfTrain2: MinamoTrainData = {
|
||||
map1: map2,
|
||||
map2: map2,
|
||||
topoSimilarity: 1,
|
||||
visionSimilarity: 1,
|
||||
size: size
|
||||
};
|
||||
data.push(selfTrain2);
|
||||
}
|
||||
const transform = generateTransformData(
|
||||
id1,
|
||||
id2,
|
||||
map1,
|
||||
map2,
|
||||
topoSimilarity
|
||||
);
|
||||
const similar = generateSimilarData(id1, map1);
|
||||
return [...data, ...transform.map(v => v[1]), ...similar.map(v => v[1])];
|
||||
}
|
||||
|
||||
export function generatePair(
|
||||
data: Record<string, MinamoTrainData>,
|
||||
id1: string,
|
||||
id2: string,
|
||||
map1: number[][],
|
||||
map2: number[][],
|
||||
size: [number, number]
|
||||
) {
|
||||
const topoSimilarity = compareMap(id1, id2, map1, map2);
|
||||
const visionSimilarity = calculateVisualSimilarity(map1, map2);
|
||||
const train: MinamoTrainData = {
|
||||
map1,
|
||||
map2,
|
||||
topoSimilarity,
|
||||
visionSimilarity,
|
||||
size: size
|
||||
};
|
||||
data[`${id1}:${id2}`] = train;
|
||||
// 自身与自身对比的训练集,保证模型对相同地图输出 1
|
||||
const self1 = `${id1}:${id1}`;
|
||||
const self2 = `${id2}:${id2}`;
|
||||
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
|
||||
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
|
||||
const selfTrain1: MinamoTrainData = {
|
||||
map1: map1,
|
||||
map2: map1,
|
||||
topoSimilarity: 1,
|
||||
visionSimilarity: 1,
|
||||
size: size
|
||||
};
|
||||
data[`${id1}:${id1}`] = selfTrain1;
|
||||
}
|
||||
if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) {
|
||||
const selfTrain2: MinamoTrainData = {
|
||||
map1: map2,
|
||||
map2: map2,
|
||||
topoSimilarity: 1,
|
||||
visionSimilarity: 1,
|
||||
size: size
|
||||
};
|
||||
data[`${id2}:${id2}`] = selfTrain2;
|
||||
}
|
||||
// 翻转、旋转训练集
|
||||
Object.assign(
|
||||
data,
|
||||
Object.fromEntries(
|
||||
generateTransformData(id1, id2, map1, map2, topoSimilarity)
|
||||
)
|
||||
);
|
||||
// 地图微调训练集
|
||||
Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1)));
|
||||
}
|
||||
|
||||
function generateDataset(
|
||||
floors: Map<string, FloorData>,
|
||||
pairs: number[],
|
||||
floorIds: string[]
|
||||
): Record<string, MinamoTrainData> {
|
||||
const data: Record<string, MinamoTrainData> = {};
|
||||
|
||||
const progress = new SingleBar({}, Presets.shades_classic);
|
||||
|
||||
progress.start(pairs.length, 0);
|
||||
|
||||
pairs.forEach((v, i) => {
|
||||
const num1 = Math.floor(v / floorIds.length);
|
||||
const num2 = v % floorIds.length;
|
||||
const id1 = floorIds[num1];
|
||||
const id2 = floorIds[num2];
|
||||
const map1 = floors.get(id1)?.map;
|
||||
const map2 = floors.get(id2)?.map;
|
||||
if (!map1 || !map2) return;
|
||||
const [w1, h1] = [map1[0].length, map1.length];
|
||||
const [w2, h2] = [map2[0].length, map2.length];
|
||||
if (w1 !== w2 || h1 !== h2) return;
|
||||
generatePair(data, id1, id2, map1, map2, [w1, h1]);
|
||||
progress.update(i + 1);
|
||||
});
|
||||
|
||||
progress.stop();
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
export function parseMinamo(data: Map<string, FloorData>): MinamoDataset {
|
||||
const length = data.size;
|
||||
const totalCount = Math.round((length * (length - 1)) / 2);
|
||||
|
||||
const pairs = choosePair(length, 10000);
|
||||
|
||||
console.log(
|
||||
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairs.length} 个组合`
|
||||
);
|
||||
|
||||
const trainData = generateDataset(data, pairs, [...data.keys()]);
|
||||
|
||||
const dataset: MinamoDataset = {
|
||||
datasetId: Math.floor(Math.random() * 1e12),
|
||||
data: trainData
|
||||
};
|
||||
|
||||
return dataset;
|
||||
}
|
||||
|
||||
export function generateAssignedData(
|
||||
data1: Map<string, FloorData>,
|
||||
data2: Map<string, FloorData>,
|
||||
count: [number, number]
|
||||
): MinamoDataset {
|
||||
const length = data1.size + data2.size;
|
||||
const totalCount = data1.size * data2.size;
|
||||
const count1 = Math.min(count[0], data1.size);
|
||||
const count2 = Math.min(count[1], data2.size);
|
||||
const keys1 = [...data1.keys()];
|
||||
const keys2 = [...data2.keys()];
|
||||
const choose1 = chooseFrom(keys1, count1);
|
||||
|
||||
const trainData: Record<string, MinamoTrainData> = {};
|
||||
|
||||
console.log(
|
||||
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${
|
||||
count1 * count2
|
||||
} 个组合`
|
||||
);
|
||||
|
||||
const progress = new SingleBar({}, Presets.shades_classic);
|
||||
progress.start(count1 * count2, 0);
|
||||
let n = 0;
|
||||
|
||||
for (const key1 of choose1) {
|
||||
const choose2 = chooseFrom(keys2, count2);
|
||||
for (const key2 of choose2) {
|
||||
const { map: map1 } = data1.get(key1)!;
|
||||
const { map: map2 } = data2.get(key2)!;
|
||||
if (!map1 || !map2) continue;
|
||||
const [w1, h1] = [map1[0].length, map1.length];
|
||||
const [w2, h2] = [map2[0].length, map2.length];
|
||||
if (w1 !== w2 || h1 !== h2) continue;
|
||||
generatePair(trainData, key1, key2, map1, map2, [w1, h1]);
|
||||
n++;
|
||||
progress.update(n);
|
||||
}
|
||||
}
|
||||
|
||||
progress.stop();
|
||||
|
||||
const dataset: MinamoDataset = {
|
||||
datasetId: Math.floor(Math.random() * 1e12),
|
||||
data: trainData
|
||||
};
|
||||
|
||||
return dataset;
|
||||
}
|
||||
@ -11,3 +11,33 @@ export interface TowerInfo {
|
||||
floorIds: string[];
|
||||
config: BaseConfig;
|
||||
}
|
||||
|
||||
export interface GinkaConfig extends BaseConfig {
|
||||
data: Record<string, string[]>;
|
||||
}
|
||||
|
||||
export interface GinkaTrainData {
|
||||
text: string[];
|
||||
map: number[][];
|
||||
size: [number, number];
|
||||
}
|
||||
|
||||
export interface GinkaDataset {
|
||||
datasetId: number;
|
||||
data: Record<string, GinkaTrainData>;
|
||||
}
|
||||
|
||||
export interface MinamoConfig extends BaseConfig {}
|
||||
|
||||
export interface MinamoTrainData {
|
||||
map1: number[][];
|
||||
map2: number[][];
|
||||
topoSimilarity: number;
|
||||
visionSimilarity: number;
|
||||
size: [number, number];
|
||||
}
|
||||
|
||||
export interface MinamoDataset {
|
||||
datasetId: number;
|
||||
data: Record<string, MinamoTrainData>;
|
||||
}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import random
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
@ -16,6 +17,12 @@ def load_data(path: str):
|
||||
|
||||
return data_list
|
||||
|
||||
def load_minamo_gan_data(data: list):
|
||||
res = list()
|
||||
for one in data:
|
||||
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
|
||||
return res
|
||||
|
||||
class GinkaDataset(Dataset):
|
||||
def __init__(self, data_path: str, device, minamo: MinamoModel):
|
||||
self.data = load_data(data_path) # 自定义数据加载函数
|
||||
@ -40,4 +47,45 @@ class GinkaDataset(Dataset):
|
||||
"target_topo_feat": topo_feat,
|
||||
"target": target,
|
||||
}
|
||||
|
||||
|
||||
class MinamoGANDataset(Dataset):
|
||||
def __init__(self, refer_data_path):
|
||||
self.refer = load_minamo_gan_data(load_data(refer_data_path))
|
||||
self.data = list().extend(self.refer)
|
||||
|
||||
def set_data(self, data: list):
|
||||
self.data = data.extend(self.refer)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
|
||||
map1, map2, vis_sim, topo_sim, review = item
|
||||
map1 = torch.ShortTensor(map1)
|
||||
map2 = torch.ShortTensor(map2)
|
||||
# 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换
|
||||
if review:
|
||||
map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
map2 = F.one_hot(map2, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
|
||||
min_main = random.uniform(0.75, 0.9)
|
||||
max_main = random.uniform(0.9, 1)
|
||||
epsilon = random.uniform(0, 0.25)
|
||||
|
||||
if review:
|
||||
map1 = random_smooth_onehot(map1, min_main, max_main, epsilon)
|
||||
map2 = random_smooth_onehot(map2, min_main, max_main, epsilon)
|
||||
|
||||
graph1 = differentiable_convert_to_data(map1)
|
||||
graph2 = differentiable_convert_to_data(map2)
|
||||
|
||||
return (
|
||||
map1,
|
||||
map2,
|
||||
torch.FloatTensor([vis_sim]),
|
||||
torch.FloatTensor([topo_sim]),
|
||||
graph1,
|
||||
graph2
|
||||
)
|
||||
@ -1,4 +1,5 @@
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -253,11 +254,8 @@ class GinkaLoss(nn.Module):
|
||||
minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim
|
||||
minamo_loss = (1.0 - minamo_sim).mean()
|
||||
|
||||
print(
|
||||
minamo_loss.item(),
|
||||
class_loss.item(),
|
||||
entrance_loss.item(),
|
||||
count_loss.item()
|
||||
tqdm.write(
|
||||
f"{minamo_loss.item():.8f}, {class_loss.item():.8f}, {entrance_loss.item():.8f}, {count_loss.item():.8f}"
|
||||
)
|
||||
|
||||
losses = [
|
||||
@ -267,7 +265,4 @@ class GinkaLoss(nn.Module):
|
||||
count_loss * self.weight[3]
|
||||
]
|
||||
|
||||
# 梯度归一化
|
||||
scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
|
||||
total_loss = sum(scaled_losses)
|
||||
return total_loss, sum(losses)
|
||||
return sum(losses)
|
||||
|
||||
@ -78,7 +78,7 @@ def train():
|
||||
_, output_softmax = model(feat_vec)
|
||||
|
||||
# 计算损失
|
||||
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
|
||||
# 反向传播
|
||||
losses.backward()
|
||||
@ -111,7 +111,7 @@ def train():
|
||||
print(torch.argmax(output, dim=1)[0])
|
||||
|
||||
# 计算损失
|
||||
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
loss_val += losses.item()
|
||||
|
||||
avg_val_loss = loss_val / len(dataloader_val)
|
||||
|
||||
305
ginka/train_gan.py
Normal file
305
ginka/train_gan.py
Normal file
@ -0,0 +1,305 @@
|
||||
import argparse
|
||||
import socket
|
||||
import struct
|
||||
import os
|
||||
from datetime import datetime
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.loader import DataLoader
|
||||
from tqdm import tqdm
|
||||
import cv2
|
||||
import numpy as np
|
||||
from .model.model import GinkaModel
|
||||
from .model.loss import GinkaLoss
|
||||
from .dataset import GinkaDataset, MinamoGANDataset
|
||||
from minamo.model.model import MinamoModel
|
||||
from minamo.model.loss import MinamoLoss
|
||||
from shared.image import matrix_to_image_cv
|
||||
|
||||
BATCH_SIZE = 32
|
||||
EPOCHS_GINKA = 30
|
||||
EPOCHS_MINAMO = 10
|
||||
SOCKET_PATH = "./tmp/ginka_uds"
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
||||
os.makedirs("tmp", exist_ok=True)
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="training codes")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
parser.add_argument("--from_state", type=str, default="result/ginka.pth")
|
||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||
parser.add_argument("--validate", type=str, default='ginka-eval.json')
|
||||
parser.add_argument("--from_cycle", type=int, default=2)
|
||||
parser.add_argument("--to_cycle", type=int, default=100)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def parse_ginka_batch(batch):
|
||||
target = batch["target"].to(device)
|
||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||
|
||||
return target, target_vision_feat, target_topo_feat, feat_vec
|
||||
|
||||
def parse_minamo_batch(batch):
|
||||
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
|
||||
map1 = map1.to(device) # 转为 [B, C, H, W]
|
||||
map2 = map2.to(device)
|
||||
topo_simi = topo_simi.to(device)
|
||||
vision_simi = vision_simi.to(device)
|
||||
graph1 = graph1.to(device)
|
||||
graph2 = graph2.to(device)
|
||||
return map1, map2, vision_simi, topo_simi, graph1, graph2
|
||||
|
||||
def send_all(sock, data):
|
||||
total_sent = 0
|
||||
while total_sent < len(data):
|
||||
sent = sock.send(data[total_sent:])
|
||||
if sent == 0:
|
||||
raise RuntimeError("Socket connection broken")
|
||||
total_sent += sent
|
||||
|
||||
def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
||||
# 数据通讯 node 输出协议,单位字节:
|
||||
# 2 - Tensor count; 2 - Review count. Review is right behind train data;
|
||||
# 1*tc - Compare count for every map tensor delivered.
|
||||
# 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo;
|
||||
# N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor.
|
||||
_, _, H, W = maps.shape
|
||||
tc_buf = sock.recv(2)
|
||||
rc_buf = sock.recv(2)
|
||||
tc = struct.unpack('>h', tc_buf)[0]
|
||||
rc = struct.unpack('>h', rc_buf)[0]
|
||||
count_buf = sock.recv(1 * tc)
|
||||
count: list = struct.unpack(f">{tc}b", count_buf)[0]
|
||||
N = sum(count)
|
||||
sim_buf = sock.recv(2 * 4 * (N + rc))
|
||||
com_buf = sock.recv(N * 1 * H * W)
|
||||
review_buf = sock.recv(rc * 2 * H * W) if rc > 0 else bytes()
|
||||
|
||||
sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)[0]
|
||||
com = struct.unpack(f">{N * 1 * H * W}b", com_buf)[0]
|
||||
review = struct.unpack(f">{rc * 2 * H * W}", review_buf)[0] if rc > 0 else list()
|
||||
|
||||
res = list()
|
||||
flatten_idx = 0
|
||||
# 读取当前这一轮生成器的数据
|
||||
for idx in range(N):
|
||||
com_count = count[idx]
|
||||
for i in range(com_count):
|
||||
com_start = flatten_idx * H * W
|
||||
com_end = (flatten_idx + 1) * H * W
|
||||
vis_sim = sim[flatten_idx * 2]
|
||||
topo_sim = sim[flatten_idx * 2 + 1]
|
||||
com_data = com[com_start:com_end]
|
||||
flatten_idx += 1
|
||||
com_map = np.fromiter(com_data, np.int8).view(H, W)
|
||||
# map1, map2, vision_similarity, topo_similarity, is_review
|
||||
res.append((maps[idx], com_map, vis_sim, topo_sim, False))
|
||||
|
||||
return res
|
||||
|
||||
def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
|
||||
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
|
||||
|
||||
ginka = GinkaModel()
|
||||
ginka.to(device)
|
||||
minamo = MinamoModel(32)
|
||||
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
||||
minamo.to(device)
|
||||
minamo.eval()
|
||||
|
||||
# 准备数据集
|
||||
ginka_dataset = GinkaDataset(args.train, device, minamo)
|
||||
ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
|
||||
minamo_dataset = MinamoGANDataset()
|
||||
minamo_dataset_val = MinamoGANDataset()
|
||||
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer_ginka = optim.AdamW(ginka.parameters(), lr=1e-3)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion_ginka = GinkaLoss(minamo)
|
||||
|
||||
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
|
||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion_minamo = MinamoLoss()
|
||||
|
||||
# 用于生成图片
|
||||
tile_dict = dict()
|
||||
for file in os.listdir('tiles'):
|
||||
name = os.path.splitext(file)[0]
|
||||
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
||||
|
||||
# 与 node 端通讯
|
||||
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server.bind(SOCKET_PATH)
|
||||
server.listen(1)
|
||||
|
||||
if args.resume:
|
||||
data = torch.load(args.from_state, map_location=device)
|
||||
ginka.load_state_dict(data["model_state"], strict=False)
|
||||
if args.load_optim:
|
||||
optimizer_ginka.load_state_dict(data["optimizer_state"])
|
||||
print("Train from loaded state.")
|
||||
|
||||
else:
|
||||
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
|
||||
criterion_ginka.weight[0] = 0.0
|
||||
|
||||
for cycle in tqdm(range(args.from_cycle, args.to_cycle)):
|
||||
# -------------------- 训练生成器
|
||||
gen_list: np.ndarray = np.empty(np.int8)
|
||||
prob_list: np.ndarray = np.empty(np.float32)
|
||||
for epoch in tqdm(range(args.epochs), desc="Training Ginka Model"):
|
||||
ginka.train()
|
||||
minamo.eval()
|
||||
total_loss = 0
|
||||
|
||||
# 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来
|
||||
if not args.resume and epoch == 10:
|
||||
criterion_ginka.weight[0] = 0.5
|
||||
|
||||
for batch in ginka_dataloader:
|
||||
# 数据迁移到设备
|
||||
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
||||
# 前向传播
|
||||
optimizer_ginka.zero_grad()
|
||||
_, output_softmax = ginka(feat_vec)
|
||||
# 计算损失
|
||||
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
# 反向传播
|
||||
losses.backward()
|
||||
optimizer_ginka.step()
|
||||
total_loss += losses.item()
|
||||
|
||||
avg_loss = total_loss / len(ginka_dataloader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
|
||||
|
||||
# 学习率调整
|
||||
scheduler_ginka.step()
|
||||
|
||||
if (epoch + 1) % 5 == 0:
|
||||
loss_val = 0
|
||||
ginka.eval()
|
||||
idx = 0
|
||||
with torch.no_grad():
|
||||
for batch in ginka_dataloader_val:
|
||||
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
||||
output, output_softmax = ginka(feat_vec)
|
||||
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
loss_val += losses.item()
|
||||
if epoch + 1 == EPOCHS_GINKA:
|
||||
# 最后一次验证的时候顺带生成图片
|
||||
prob = output_softmax.cpu().numpy()
|
||||
np.concatenate((prob_list, prob), axis=1)
|
||||
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
|
||||
gen_list = np.concatenate((gen_list, map_matrix), axis=1)
|
||||
for matrix in map_matrix:
|
||||
image = matrix_to_image_cv(matrix, tile_dict)
|
||||
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
|
||||
idx += 1
|
||||
|
||||
avg_val_loss = loss_val / len(ginka_dataloader_val)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||
torch.save({
|
||||
"model_state": ginka.state_dict()
|
||||
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
|
||||
|
||||
tqdm.write(f"Cycle {cycle} Ginka train ended.")
|
||||
torch.save({
|
||||
"model_state": ginka.state_dict()
|
||||
}, f"result/ginka.pth")
|
||||
|
||||
# -------------------- 生成 Minamo 的训练数据
|
||||
|
||||
# 数据通讯 python 输出协议,单位字节:
|
||||
# 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
|
||||
N, H, W = gen_list.shape
|
||||
gen_bytes = gen_list.astype(np.int8).tobytes()
|
||||
buf = bytearray()
|
||||
buf.extend(struct.pack('>h', N)) # Tensor count
|
||||
buf.extend(struct.pack('>b', H)) # Map height
|
||||
buf.extend(struct.pack('>b', W)) # Map width
|
||||
buf.extend(gen_bytes) # Map tensor
|
||||
server.sendall(buf)
|
||||
data = parse_minamo_data(server, prob_list)
|
||||
minamo_dataset.set_data(data)
|
||||
|
||||
# -------------------- 训练判别器
|
||||
for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
|
||||
ginka.eval()
|
||||
minamo.train()
|
||||
total_loss = 0
|
||||
|
||||
for batch in minamo_dataloader:
|
||||
map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch)
|
||||
|
||||
if map1.shape[0] == 1:
|
||||
continue
|
||||
|
||||
# 前向传播
|
||||
optimizer_minamo.zero_grad()
|
||||
vision_feat1, topo_feat1 = minamo(map1, graph1)
|
||||
vision_feat2, topo_feat2 = minamo(map2, graph2)
|
||||
|
||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion_minamo(vision_pred, topo_pred, vision_simi, topo_simi)
|
||||
|
||||
# 反向传播
|
||||
loss.backward()
|
||||
optimizer_minamo.step()
|
||||
total_loss += loss.item()
|
||||
|
||||
ave_loss = total_loss / len(minamo_dataloader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
|
||||
|
||||
scheduler_minamo.step()
|
||||
|
||||
# 每十轮推理一次验证集
|
||||
if (epoch + 1) % 5 == 0:
|
||||
minamo.eval()
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
for val_batch in tqdm(minamo_dataloader_val, leave=False):
|
||||
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)
|
||||
|
||||
vision_feat1, topo_feat1 = minamo(map1_val, graph1)
|
||||
vision_feat2, topo_feat2 = minamo(map2_val, graph2)
|
||||
|
||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
|
||||
# 计算损失
|
||||
loss_val = criterion_minamo(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
|
||||
val_loss += loss_val.item()
|
||||
|
||||
avg_val_loss = val_loss / len(minamo_dataloader_val)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||
torch.save({
|
||||
"model_state": minamo.state_dict()
|
||||
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
|
||||
|
||||
tqdm.write(f"Cycle {cycle} Minamo train ended.")
|
||||
torch.save({
|
||||
"model_state": minamo.state_dict()
|
||||
}, f"result/ginka.pth")
|
||||
|
||||
print("Train ended.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(4)
|
||||
train()
|
||||
@ -10,59 +10,11 @@ from minamo.model.model import MinamoModel
|
||||
from .dataset import GinkaDataset
|
||||
from .model.loss import GinkaLoss
|
||||
from .model.model import GinkaModel
|
||||
from shared.image import matrix_to_image_cv
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs('result/ginka_img', exist_ok=True)
|
||||
|
||||
def blend_alpha(bg, fg, alpha):
|
||||
""" 使用 alpha 通道混合前景图块和背景图 """
|
||||
for c in range(3): # 只混合 RGB 三个通道
|
||||
bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c]
|
||||
return bg
|
||||
|
||||
def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||
"""
|
||||
使用OpenCV加速的版本(适合大尺寸地图)
|
||||
:param map_matrix: [H, W] 的numpy数组
|
||||
:param tile_set: 字典 {tile_id: cv2图像(BGR格式)}
|
||||
:param tile_size: 图块边长(像素)
|
||||
"""
|
||||
H, W = map_matrix.shape # 获取地图尺寸
|
||||
canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景)
|
||||
|
||||
# 遍历地图矩阵
|
||||
for row in range(H):
|
||||
for col in range(W):
|
||||
tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型
|
||||
x, y = col * tile_size, row * tile_size # 计算像素位置
|
||||
|
||||
# 先绘制地面(0)
|
||||
if '0' in tile_set:
|
||||
canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB
|
||||
|
||||
if tile_index == '11':
|
||||
if row == 0:
|
||||
tile_index = '11_1'
|
||||
elif row == W - 1:
|
||||
tile_index = '11_3'
|
||||
elif col == 0:
|
||||
tile_index = '11_2'
|
||||
elif col == H - 1:
|
||||
tile_index = '11_4'
|
||||
|
||||
# 叠加其他透明图块
|
||||
if tile_index in tile_set and tile_index != 0:
|
||||
tile_rgba = tile_set[tile_index]
|
||||
tile_rgb = tile_rgba[:, :, :3] # 提取 RGB
|
||||
alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha
|
||||
|
||||
# 混合当前图块到背景
|
||||
canvas[y:y+tile_size, x:x+tile_size] = blend_alpha(
|
||||
canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha
|
||||
)
|
||||
|
||||
return canvas
|
||||
|
||||
def validate():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||
model = GinkaModel()
|
||||
|
||||
50
shared/image.py
Normal file
50
shared/image.py
Normal file
@ -0,0 +1,50 @@
|
||||
import numpy as np
|
||||
|
||||
def blend_alpha(bg, fg, alpha):
|
||||
""" 使用 alpha 通道混合前景图块和背景图 """
|
||||
for c in range(3): # 只混合 RGB 三个通道
|
||||
bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c]
|
||||
return bg
|
||||
|
||||
def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||
"""
|
||||
使用OpenCV加速的版本(适合大尺寸地图)
|
||||
:param map_matrix: [H, W] 的numpy数组
|
||||
:param tile_set: 字典 {tile_id: cv2图像(BGR格式)}
|
||||
:param tile_size: 图块边长(像素)
|
||||
"""
|
||||
H, W = map_matrix.shape # 获取地图尺寸
|
||||
canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景)
|
||||
|
||||
# 遍历地图矩阵
|
||||
for row in range(H):
|
||||
for col in range(W):
|
||||
tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型
|
||||
x, y = col * tile_size, row * tile_size # 计算像素位置
|
||||
|
||||
# 先绘制地面(0)
|
||||
if '0' in tile_set:
|
||||
canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB
|
||||
|
||||
if tile_index == '11':
|
||||
if row == 0:
|
||||
tile_index = '11_1'
|
||||
elif row == W - 1:
|
||||
tile_index = '11_3'
|
||||
elif col == 0:
|
||||
tile_index = '11_2'
|
||||
elif col == H - 1:
|
||||
tile_index = '11_4'
|
||||
|
||||
# 叠加其他透明图块
|
||||
if tile_index in tile_set and tile_index != 0:
|
||||
tile_rgba = tile_set[tile_index]
|
||||
tile_rgb = tile_rgba[:, :, :3] # 提取 RGB
|
||||
alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha
|
||||
|
||||
# 混合当前图块到背景
|
||||
canvas[y:y+tile_size, x:x+tile_size] = blend_alpha(
|
||||
canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha
|
||||
)
|
||||
|
||||
return canvas
|
||||
Loading…
Reference in New Issue
Block a user