feat: 使用 IPC 通讯并实现单脚本对抗训练

This commit is contained in:
unanmed 2025-04-01 22:41:23 +08:00
parent 4721e9a141
commit 8dea79a9f0
13 changed files with 1004 additions and 474 deletions

View File

@ -8,6 +8,7 @@
"minamo": "tsx ./src/minamo.ts",
"merge": "tsx ./src/merge.ts",
"review": "tsx ./src/review.ts",
"gan": "tsx ./src/gan.ts",
"test:topo": "tsx ./src/topology/test.ts",
"test:vision": "tsx ./src/vision/test.ts"
},

131
data/src/gan.ts Normal file
View File

@ -0,0 +1,131 @@
import { createConnection, Socket } from 'net';
import { chooseFrom, FloorData, readOne } from './utils';
import { MinamoTrainData } from './types';
import { generateTrainData } from './process/minamo';
const SOCKET_FILE = '../tmp/ginka_uds';
const [refer] = process.argv.slice(2);
let id = 0;
function readMap(count: number, buffer: Buffer, h: number, w: number) {
const area = w * h;
const maps: number[][][] = Array.from<number[][]>({
length: count
}).map(() => {
return Array.from<number[]>({ length: h }).map(() => {
return Array.from<number>({ length: w }).fill(0);
});
});
buffer.subarray(4).forEach((v, i) => {
const n = Math.floor(i / area);
const y = Math.floor((i % area) / w);
const x = i % w;
maps[n][y][x] = v;
});
return maps;
}
function generateGANData(
keys: string[],
refer: Map<string, FloorData>,
map: number[][]
) {
const id2 = `$${id++}`;
const toTrain = chooseFrom(keys, 4);
const data = toTrain.map<MinamoTrainData[]>(v => {
const floor = refer.get(v);
if (!floor) return [];
const size1: [number, number] = [floor.map[0].length, floor.map.length];
const size2: [number, number] = [map[0].length, map.length];
if (size1[0] !== size2[0] || size1[1] !== size2[1]) return [];
return generateTrainData(v, id2, floor.map, map, size1);
});
return data.flat();
}
(async () => {
const referTower = await readOne(refer);
const keys = [...referTower.keys()];
const client = createConnection(SOCKET_FILE, () => {
console.log(`UDS IPC connected successfully.`);
// 发送四字节数据表示连接成功
client.write(new Uint8Array([0x00, 0x00, 0x00, 0x00]));
});
client.on('data', buffer => {
// 暂时不考虑流式传输,如果后续数据量非常大,再考虑优化
// 数据通讯 node 输入协议,单位字节:
// 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
const count = buffer.readInt16BE();
if (buffer.length - 4 !== count * 32 * 32) {
client.write(`ERROR: byte length not match.`);
return [];
}
const h = buffer.readInt8(2);
const w = buffer.readInt8(3);
const map = readMap(count, buffer, h, w);
const simData = map.map(v => generateGANData(keys, referTower, v));
const rc = 0;
const compareData = simData.flat();
const reviewData: MinamoTrainData[] = [];
// 数据通讯 node 输出协议,单位字节:
// 2 - Tensor count; 2 - Review count. Review is right behind train data;
// 1*tc - Compare count for every map tensor delivered.
// 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo;
// N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor.
const toSend = Buffer.alloc(
2 + // Tensor count
2 + // Review count
count + // Compare count
2 * (count + rc) + // Similarity data
compareData.length * 1 * h * w + // Compare map
rc * 2 * h * w, // Review map
0
);
let offset = 0;
toSend.writeInt16BE(count); // Tensor count
toSend.writeInt16BE(0, 2); // Review count
offset += 2 + 2;
// Compare count
toSend.set(
simData.map(v => v.length),
offset
);
offset += count;
// Similarity data
compareData.forEach(v => {
toSend.writeFloatBE(v.visionSimilarity, offset);
offset += 4;
toSend.writeFloatBE(v.topoSimilarity, offset);
offset += 4;
});
// Compare map
toSend.set(
compareData.map(v => v.map1).flat(2),
offset // Set from Compare map
);
offset += compareData.length * 1 * h * w;
// Review map
toSend.set(
reviewData.map(v => [v.map1, v.map2]).flat(3),
offset // Set from last chunk
);
client.write(toSend);
});
client.on('end', () => {
console.log(`Connection lose.`);
});
client.on('error', () => {
client.end();
});
})();

View File

@ -1,61 +1,15 @@
import { writeFile } from 'fs-extra';
import { FloorData, getAllFloors, parseTowerInfo } from './utils';
import { Presets, SingleBar } from 'cli-progress';
interface GinkaConfig {
clip: {
defaults: [number, number, number, number];
special: Record<string, [number, number, number, number]>;
};
data: Record<string, string[]>;
}
interface GinkaTrainData {
text: string[];
map: number[][];
size: [number, number];
}
interface GinkaDataset {
datasetId: number;
data: Record<string, GinkaTrainData>;
}
import { getAllFloors, parseTowerInfo } from './utils';
import { parseGinka } from './process/ginka';
const [output, ...list] = process.argv.slice(2);
function parseAllData(data: Map<string, FloorData>) {
const resolved: Record<string, GinkaTrainData> = {};
const progress = new SingleBar({}, Presets.shades_classic);
progress.start(data.size, 0);
let i = 0;
data.forEach((floor, key) => {
const config = floor.config as GinkaConfig;
const text = config.data[floor.id] ?? [];
resolved[key] = {
map: floor.map,
size: [floor.map[0].length, floor.map.length],
text: text
};
i++;
progress.update(i);
});
const dataset: GinkaDataset = {
datasetId: Math.floor(Math.random() * 1e12),
data: resolved
};
return dataset;
}
(async () => {
const towers = await Promise.all(
list.map(v => parseTowerInfo(v, 'ginka-config.json'))
);
const floors = await getAllFloors(...towers);
const results = parseAllData(floors);
const results = parseGinka(floors);
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
const size = Object.keys(results.data).length;
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个地图`);

View File

@ -1,32 +1,6 @@
import { writeFile } from 'fs-extra';
import {
FloorData,
readOne,
getAllFloors,
parseTowerInfo,
chooseFrom
} from './utils';
import { compareMap } from './topology/compare';
import { mirrorMapX, mirrorMapY, rotateMap } from './topology/transform';
import { directions, tileType } from './topology/graph';
import { calculateVisualSimilarity } from './vision/similarity';
import { BaseConfig } from './types';
import { Presets, SingleBar } from 'cli-progress';
interface MinamoConfig extends BaseConfig {}
interface MinamoTrainData {
map1: number[][];
map2: number[][];
topoSimilarity: number;
visionSimilarity: number;
size: [number, number];
}
interface MinamoDataset {
datasetId: number;
data: Record<string, MinamoTrainData>;
}
import { readOne, getAllFloors, parseTowerInfo } from './utils';
import { generateAssignedData, parseMinamo } from './process/minamo';
const [output, ...list] = process.argv.slice(2);
// 判断 assigned 模式,此模式下只会对前两个塔处理,会在这两个塔之间对比,而单个塔的地图不会对比
@ -40,348 +14,13 @@ function parseAssigned(arg: string): [number, number] {
return [parseInt(a) || 100, parseInt(b) || 100];
}
function chooseN(maxCount: number, n: number) {
return chooseFrom(
Array(maxCount)
.fill(0)
.map((_, i) => i),
n
);
}
function choosePair(n: number, max: number = 1000) {
const totalCount = Math.round((n * (n - 1)) / 2);
const count = Math.min(totalCount, max);
const pairs: number[] = [];
for (let i = 0; i < n; i++) {
for (let j = i + 1; j < n; j++) {
pairs.push(i * n + j);
}
}
// 直接打乱后取前 count 个
for (let i = pairs.length - 1; i > 0; i--) {
let randIndex = Math.floor(Math.random() * (i + 1));
[pairs[i], pairs[randIndex]] = [pairs[randIndex], pairs[i]];
}
return pairs.slice(0, count);
}
function transform(map: number[][], rot: number, flip: number) {
let res = map;
for (let i = 0; i < rot; i++) {
res = rotateMap(res);
}
if (flip & 0b01) {
res = mirrorMapX(res);
}
if (flip & 0b10) {
res = mirrorMapY(res);
}
return res;
}
function generateTransformData(
id1: string,
id2: string,
map1: number[][],
map2: number[][],
simi: number
) {
const types: [rot: number, flip: number][] = [];
for (const rot of [0, 1, 2, 3]) {
for (const flip of [0b00, 0b01, 0b10, 0b11]) {
if (rot === 0 && flip === 0) continue;
types.push([rot, flip]);
}
}
// 随机抽取最多一个
const trans = chooseFrom(types, Math.floor(Math.random() * 1));
return trans
.map(([rot, flip]) => {
const com1 = `${id1}.${rot}.${flip}:${id1}`;
const com2 = `${id1}.${rot}.${flip}:${id2}`;
const com3 = `${id2}.${rot}.${flip}:${id1}`;
const com4 = `${id2}.${rot}.${flip}:${id2}`;
const choose = chooseFrom(
[com1, com2, com3, com4],
Math.floor(Math.random() * 2)
);
const res: [id: string, data: MinamoTrainData][] = [];
if (choose.includes(com1)) {
const t = transform(map1, rot, flip);
res.push([
com1,
{
map1: t,
map2: map1,
topoSimilarity: 1,
visionSimilarity: calculateVisualSimilarity(map1, t),
size: [map1[0].length, map1.length]
}
]);
}
if (choose.includes(com2)) {
const t = transform(map1, rot, flip);
res.push([
com2,
{
map1: t,
map2: map2,
topoSimilarity: simi,
visionSimilarity: calculateVisualSimilarity(t, map2),
size: [map1[0].length, map1.length]
}
]);
}
if (choose.includes(com3)) {
const t = transform(map2, rot, flip);
res.push([
com3,
{
map1: t,
map2: map1,
topoSimilarity: simi,
visionSimilarity: calculateVisualSimilarity(t, map1),
size: [map1[0].length, map1.length]
}
]);
}
if (choose.includes(com4)) {
const t = transform(map2, rot, flip);
res.push([
com4,
{
map1: t,
map2: map2,
topoSimilarity: 1,
visionSimilarity: calculateVisualSimilarity(t, map2),
size: [map1[0].length, map1.length]
}
]);
}
return res;
})
.flat();
}
function generateSimilarData(id: string, map: number[][]) {
// 生成最多两个微调地图
const width = map[0].length;
const height = map.length;
const num = Math.floor(Math.random() * 2);
const res: [id: string, data: MinamoTrainData][] = [];
for (let i = 0; i < num; i++) {
const clone = map.map(v => v.slice());
const prob = Math.random() * 0.3;
for (let ny = 0; ny < height; ny++) {
for (let nx = 0; nx < width; nx++) {
if (Math.random() > prob) {
// 有一定的概率进行微调
continue;
}
if (Math.random() < 0.2) {
// 20% 概率与旁边图块互换位置
const [dx, dy] =
directions[
Math.floor(Math.random() * directions.length)
];
const px = nx + dx;
const py = ny + dy;
if (px < 0 || px >= width || py < 0 || py >= height) {
continue;
}
[clone[ny][nx], clone[py][px]] = [
clone[py][px],
clone[ny][nx]
];
} else {
// 80% 概率替换当前图块
clone[ny][nx] = Math.floor(Math.random() * tileType.size);
}
}
}
const id2 = `${id}.S${i}`;
const sid = `${id}:${id2}`;
const simi = compareMap(id, id2, map, clone);
res.push([
sid,
{
map1: map,
map2: clone,
size: [width, height],
topoSimilarity: simi,
visionSimilarity: calculateVisualSimilarity(map, clone)
}
]);
}
return res;
}
function generatePair(
data: Record<string, MinamoTrainData>,
id1: string,
id2: string,
map1: number[][],
map2: number[][],
size: [number, number]
) {
const topoSimilarity = compareMap(id1, id2, map1, map2);
const visionSimilarity = calculateVisualSimilarity(map1, map2);
const train: MinamoTrainData = {
map1,
map2,
topoSimilarity,
visionSimilarity,
size: size
};
data[`${id1}:${id2}`] = train;
// 自身与自身对比的训练集,保证模型对相同地图输出 1
const self1 = `${id1}:${id1}`;
const self2 = `${id2}:${id2}`;
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
const selfTrain1: MinamoTrainData = {
map1: map1,
map2: map1,
topoSimilarity: 1,
visionSimilarity: 1,
size: size
};
data[`${id1}:${id1}`] = selfTrain1;
}
if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) {
const selfTrain2: MinamoTrainData = {
map1: map2,
map2: map2,
topoSimilarity: 1,
visionSimilarity: 1,
size: size
};
data[`${id2}:${id2}`] = selfTrain2;
}
// 翻转、旋转训练集
Object.assign(
data,
Object.fromEntries(
generateTransformData(id1, id2, map1, map2, topoSimilarity)
)
);
// 地图微调训练集
Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1)));
}
function generateDataset(
floors: Map<string, FloorData>,
pairs: number[],
floorIds: string[]
): Record<string, MinamoTrainData> {
const data: Record<string, MinamoTrainData> = {};
const progress = new SingleBar({}, Presets.shades_classic);
progress.start(pairs.length, 0);
pairs.forEach((v, i) => {
const num1 = Math.floor(v / floorIds.length);
const num2 = v % floorIds.length;
const id1 = floorIds[num1];
const id2 = floorIds[num2];
const map1 = floors.get(id1)?.map;
const map2 = floors.get(id2)?.map;
if (!map1 || !map2) return;
const [w1, h1] = [map1[0].length, map1.length];
const [w2, h2] = [map2[0].length, map2.length];
if (w1 !== w2 || h1 !== h2) return;
generatePair(data, id1, id2, map1, map2, [w1, h1]);
progress.update(i + 1);
});
progress.stop();
return data;
}
function parseAllData(data: Map<string, FloorData>): MinamoDataset {
const length = data.size;
const totalCount = Math.round((length * (length - 1)) / 2);
const pairs = choosePair(length, 10000);
console.log(
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairs.length} 个组合`
);
const trainData = generateDataset(data, pairs, [...data.keys()]);
const dataset: MinamoDataset = {
datasetId: Math.floor(Math.random() * 1e12),
data: trainData
};
return dataset;
}
function generateAssignedData(
data1: Map<string, FloorData>,
data2: Map<string, FloorData>,
count: [number, number]
): MinamoDataset {
const length = data1.size + data2.size;
const totalCount = data1.size * data2.size;
const count1 = Math.min(count[0], data1.size);
const count2 = Math.min(count[1], data2.size);
const keys1 = [...data1.keys()];
const keys2 = [...data2.keys()];
const choose1 = chooseFrom(keys1, count1);
const trainData: Record<string, MinamoTrainData> = {};
console.log(
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${
count1 * count2
} `
);
const progress = new SingleBar({}, Presets.shades_classic);
progress.start(count1 * count2, 0);
let n = 0;
for (const key1 of choose1) {
const choose2 = chooseFrom(keys2, count2);
for (const key2 of choose2) {
const { map: map1 } = data1.get(key1)!;
const { map: map2 } = data2.get(key2)!;
if (!map1 || !map2) continue;
const [w1, h1] = [map1[0].length, map1.length];
const [w2, h2] = [map2[0].length, map2.length];
if (w1 !== w2 || h1 !== h2) continue;
generatePair(trainData, key1, key2, map1, map2, [w1, h1]);
n++;
progress.update(n);
}
}
progress.stop();
const dataset: MinamoDataset = {
datasetId: Math.floor(Math.random() * 1e12),
data: trainData
};
return dataset;
}
(async () => {
if (!assigned) {
const towers = await Promise.all(
list.map(v => parseTowerInfo(v, 'minamo-config.json'))
);
const floors = await getAllFloors(...towers);
const results = parseAllData(floors);
const results = parseMinamo(floors);
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
const size = Object.keys(results.data).length;
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`);

30
data/src/process/ginka.ts Normal file
View File

@ -0,0 +1,30 @@
import { SingleBar, Presets } from 'cli-progress';
import { GinkaTrainData, GinkaConfig, GinkaDataset } from 'src/types';
import { FloorData } from 'src/utils';
export function parseGinka(data: Map<string, FloorData>) {
const resolved: Record<string, GinkaTrainData> = {};
const progress = new SingleBar({}, Presets.shades_classic);
progress.start(data.size, 0);
let i = 0;
data.forEach((floor, key) => {
const config = floor.config as GinkaConfig;
const text = config.data[floor.id] ?? [];
resolved[key] = {
map: floor.map,
size: [floor.map[0].length, floor.map.length],
text: text
};
i++;
progress.update(i);
});
const dataset: GinkaDataset = {
datasetId: Math.floor(Math.random() * 1e12),
data: resolved
};
return dataset;
}

395
data/src/process/minamo.ts Normal file
View File

@ -0,0 +1,395 @@
import { SingleBar, Presets } from 'cli-progress';
import { compareMap } from 'src/topology/compare';
import { directions, tileType } from 'src/topology/graph';
import { rotateMap, mirrorMapX, mirrorMapY } from 'src/topology/transform';
import { MinamoTrainData, MinamoDataset } from 'src/types';
import { chooseFrom, FloorData } from 'src/utils';
import { calculateVisualSimilarity } from 'src/vision/similarity';
function chooseN(maxCount: number, n: number) {
return chooseFrom(
Array(maxCount)
.fill(0)
.map((_, i) => i),
n
);
}
function choosePair(n: number, max: number = 1000) {
const totalCount = Math.round((n * (n - 1)) / 2);
const count = Math.min(totalCount, max);
const pairs: number[] = [];
for (let i = 0; i < n; i++) {
for (let j = i + 1; j < n; j++) {
pairs.push(i * n + j);
}
}
// 直接打乱后取前 count 个
for (let i = pairs.length - 1; i > 0; i--) {
let randIndex = Math.floor(Math.random() * (i + 1));
[pairs[i], pairs[randIndex]] = [pairs[randIndex], pairs[i]];
}
return pairs.slice(0, count);
}
function transform(map: number[][], rot: number, flip: number) {
let res = map;
for (let i = 0; i < rot; i++) {
res = rotateMap(res);
}
if (flip & 0b01) {
res = mirrorMapX(res);
}
if (flip & 0b10) {
res = mirrorMapY(res);
}
return res;
}
function generateTransformData(
id1: string,
id2: string,
map1: number[][],
map2: number[][],
simi: number
) {
const types: [rot: number, flip: number][] = [];
for (const rot of [0, 1, 2, 3]) {
for (const flip of [0b00, 0b01, 0b10, 0b11]) {
if (rot === 0 && flip === 0) continue;
types.push([rot, flip]);
}
}
// 随机抽取最多一个
const trans = chooseFrom(types, Math.floor(Math.random() * 1));
return trans
.map(([rot, flip]) => {
const com1 = `${id1}.${rot}.${flip}:${id1}`;
const com2 = `${id1}.${rot}.${flip}:${id2}`;
const com3 = `${id2}.${rot}.${flip}:${id1}`;
const com4 = `${id2}.${rot}.${flip}:${id2}`;
const choose = chooseFrom(
[com1, com2, com3, com4],
Math.floor(Math.random() * 2)
);
const res: [id: string, data: MinamoTrainData][] = [];
if (choose.includes(com1)) {
const t = transform(map1, rot, flip);
res.push([
com1,
{
map1: t,
map2: map1,
topoSimilarity: 1,
visionSimilarity: calculateVisualSimilarity(map1, t),
size: [map1[0].length, map1.length]
}
]);
}
if (choose.includes(com2)) {
const t = transform(map1, rot, flip);
res.push([
com2,
{
map1: t,
map2: map2,
topoSimilarity: simi,
visionSimilarity: calculateVisualSimilarity(t, map2),
size: [map1[0].length, map1.length]
}
]);
}
if (choose.includes(com3)) {
const t = transform(map2, rot, flip);
res.push([
com3,
{
map1: t,
map2: map1,
topoSimilarity: simi,
visionSimilarity: calculateVisualSimilarity(t, map1),
size: [map1[0].length, map1.length]
}
]);
}
if (choose.includes(com4)) {
const t = transform(map2, rot, flip);
res.push([
com4,
{
map1: t,
map2: map2,
topoSimilarity: 1,
visionSimilarity: calculateVisualSimilarity(t, map2),
size: [map1[0].length, map1.length]
}
]);
}
return res;
})
.flat();
}
function generateSimilarData(id: string, map: number[][]) {
// 生成最多两个微调地图
const width = map[0].length;
const height = map.length;
const num = Math.floor(Math.random() * 1);
const res: [id: string, data: MinamoTrainData][] = [];
for (let i = 0; i < num; i++) {
const clone = map.map(v => v.slice());
const prob = Math.random() * 0.3;
for (let ny = 0; ny < height; ny++) {
for (let nx = 0; nx < width; nx++) {
if (Math.random() > prob) {
// 有一定的概率进行微调
continue;
}
if (Math.random() < 0.2) {
// 20% 概率与旁边图块互换位置
const [dx, dy] =
directions[
Math.floor(Math.random() * directions.length)
];
const px = nx + dx;
const py = ny + dy;
if (px < 0 || px >= width || py < 0 || py >= height) {
continue;
}
[clone[ny][nx], clone[py][px]] = [
clone[py][px],
clone[ny][nx]
];
} else {
// 80% 概率替换当前图块
clone[ny][nx] = Math.floor(Math.random() * tileType.size);
}
}
}
const id2 = `${id}.S${i}`;
const sid = `${id}:${id2}`;
const simi = compareMap(id, id2, map, clone);
res.push([
sid,
{
map1: map,
map2: clone,
size: [width, height],
topoSimilarity: simi,
visionSimilarity: calculateVisualSimilarity(map, clone)
}
]);
}
return res;
}
export function generateTrainData(
id1: string,
id2: string,
map1: number[][],
map2: number[][],
size: [number, number]
) {
const topoSimilarity = compareMap(id1, id2, map1, map2);
const visionSimilarity = calculateVisualSimilarity(map1, map2);
const train: MinamoTrainData = {
map1,
map2,
topoSimilarity,
visionSimilarity,
size: size
};
const data: MinamoTrainData[] = [];
data.push(train);
// 自身与自身对比的训练集,保证模型对相同地图输出 1
const self1 = `${id1}:${id1}`;
const self2 = `${id2}:${id2}`;
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
if (selfTrain.includes(self1)) {
const selfTrain1: MinamoTrainData = {
map1: map1,
map2: map1,
topoSimilarity: 1,
visionSimilarity: 1,
size: size
};
data.push(selfTrain1);
}
if (selfTrain.includes(self2)) {
const selfTrain2: MinamoTrainData = {
map1: map2,
map2: map2,
topoSimilarity: 1,
visionSimilarity: 1,
size: size
};
data.push(selfTrain2);
}
const transform = generateTransformData(
id1,
id2,
map1,
map2,
topoSimilarity
);
const similar = generateSimilarData(id1, map1);
return [...data, ...transform.map(v => v[1]), ...similar.map(v => v[1])];
}
export function generatePair(
data: Record<string, MinamoTrainData>,
id1: string,
id2: string,
map1: number[][],
map2: number[][],
size: [number, number]
) {
const topoSimilarity = compareMap(id1, id2, map1, map2);
const visionSimilarity = calculateVisualSimilarity(map1, map2);
const train: MinamoTrainData = {
map1,
map2,
topoSimilarity,
visionSimilarity,
size: size
};
data[`${id1}:${id2}`] = train;
// 自身与自身对比的训练集,保证模型对相同地图输出 1
const self1 = `${id1}:${id1}`;
const self2 = `${id2}:${id2}`;
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
const selfTrain1: MinamoTrainData = {
map1: map1,
map2: map1,
topoSimilarity: 1,
visionSimilarity: 1,
size: size
};
data[`${id1}:${id1}`] = selfTrain1;
}
if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) {
const selfTrain2: MinamoTrainData = {
map1: map2,
map2: map2,
topoSimilarity: 1,
visionSimilarity: 1,
size: size
};
data[`${id2}:${id2}`] = selfTrain2;
}
// 翻转、旋转训练集
Object.assign(
data,
Object.fromEntries(
generateTransformData(id1, id2, map1, map2, topoSimilarity)
)
);
// 地图微调训练集
Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1)));
}
function generateDataset(
floors: Map<string, FloorData>,
pairs: number[],
floorIds: string[]
): Record<string, MinamoTrainData> {
const data: Record<string, MinamoTrainData> = {};
const progress = new SingleBar({}, Presets.shades_classic);
progress.start(pairs.length, 0);
pairs.forEach((v, i) => {
const num1 = Math.floor(v / floorIds.length);
const num2 = v % floorIds.length;
const id1 = floorIds[num1];
const id2 = floorIds[num2];
const map1 = floors.get(id1)?.map;
const map2 = floors.get(id2)?.map;
if (!map1 || !map2) return;
const [w1, h1] = [map1[0].length, map1.length];
const [w2, h2] = [map2[0].length, map2.length];
if (w1 !== w2 || h1 !== h2) return;
generatePair(data, id1, id2, map1, map2, [w1, h1]);
progress.update(i + 1);
});
progress.stop();
return data;
}
export function parseMinamo(data: Map<string, FloorData>): MinamoDataset {
const length = data.size;
const totalCount = Math.round((length * (length - 1)) / 2);
const pairs = choosePair(length, 10000);
console.log(
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairs.length} 个组合`
);
const trainData = generateDataset(data, pairs, [...data.keys()]);
const dataset: MinamoDataset = {
datasetId: Math.floor(Math.random() * 1e12),
data: trainData
};
return dataset;
}
export function generateAssignedData(
data1: Map<string, FloorData>,
data2: Map<string, FloorData>,
count: [number, number]
): MinamoDataset {
const length = data1.size + data2.size;
const totalCount = data1.size * data2.size;
const count1 = Math.min(count[0], data1.size);
const count2 = Math.min(count[1], data2.size);
const keys1 = [...data1.keys()];
const keys2 = [...data2.keys()];
const choose1 = chooseFrom(keys1, count1);
const trainData: Record<string, MinamoTrainData> = {};
console.log(
`✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${
count1 * count2
} `
);
const progress = new SingleBar({}, Presets.shades_classic);
progress.start(count1 * count2, 0);
let n = 0;
for (const key1 of choose1) {
const choose2 = chooseFrom(keys2, count2);
for (const key2 of choose2) {
const { map: map1 } = data1.get(key1)!;
const { map: map2 } = data2.get(key2)!;
if (!map1 || !map2) continue;
const [w1, h1] = [map1[0].length, map1.length];
const [w2, h2] = [map2[0].length, map2.length];
if (w1 !== w2 || h1 !== h2) continue;
generatePair(trainData, key1, key2, map1, map2, [w1, h1]);
n++;
progress.update(n);
}
}
progress.stop();
const dataset: MinamoDataset = {
datasetId: Math.floor(Math.random() * 1e12),
data: trainData
};
return dataset;
}

View File

@ -11,3 +11,33 @@ export interface TowerInfo {
floorIds: string[];
config: BaseConfig;
}
export interface GinkaConfig extends BaseConfig {
data: Record<string, string[]>;
}
export interface GinkaTrainData {
text: string[];
map: number[][];
size: [number, number];
}
export interface GinkaDataset {
datasetId: number;
data: Record<string, GinkaTrainData>;
}
export interface MinamoConfig extends BaseConfig {}
export interface MinamoTrainData {
map1: number[][];
map2: number[][];
topoSimilarity: number;
visionSimilarity: number;
size: [number, number];
}
export interface MinamoDataset {
datasetId: number;
data: Record<string, MinamoTrainData>;
}

View File

@ -1,4 +1,5 @@
import json
import random
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
@ -16,6 +17,12 @@ def load_data(path: str):
return data_list
def load_minamo_gan_data(data: list):
res = list()
for one in data:
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
return res
class GinkaDataset(Dataset):
def __init__(self, data_path: str, device, minamo: MinamoModel):
self.data = load_data(data_path) # 自定义数据加载函数
@ -40,4 +47,45 @@ class GinkaDataset(Dataset):
"target_topo_feat": topo_feat,
"target": target,
}
class MinamoGANDataset(Dataset):
def __init__(self, refer_data_path):
self.refer = load_minamo_gan_data(load_data(refer_data_path))
self.data = list().extend(self.refer)
def set_data(self, data: list):
self.data = data.extend(self.refer)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
map1, map2, vis_sim, topo_sim, review = item
map1 = torch.ShortTensor(map1)
map2 = torch.ShortTensor(map2)
# 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换
if review:
map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
map2 = F.one_hot(map2, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
min_main = random.uniform(0.75, 0.9)
max_main = random.uniform(0.9, 1)
epsilon = random.uniform(0, 0.25)
if review:
map1 = random_smooth_onehot(map1, min_main, max_main, epsilon)
map2 = random_smooth_onehot(map2, min_main, max_main, epsilon)
graph1 = differentiable_convert_to_data(map1)
graph2 = differentiable_convert_to_data(map2)
return (
map1,
map2,
torch.FloatTensor([vis_sim]),
torch.FloatTensor([topo_sim]),
graph1,
graph2
)

View File

@ -1,4 +1,5 @@
import math
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -253,11 +254,8 @@ class GinkaLoss(nn.Module):
minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim
minamo_loss = (1.0 - minamo_sim).mean()
print(
minamo_loss.item(),
class_loss.item(),
entrance_loss.item(),
count_loss.item()
tqdm.write(
f"{minamo_loss.item():.8f}, {class_loss.item():.8f}, {entrance_loss.item():.8f}, {count_loss.item():.8f}"
)
losses = [
@ -267,7 +265,4 @@ class GinkaLoss(nn.Module):
count_loss * self.weight[3]
]
# 梯度归一化
scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
total_loss = sum(scaled_losses)
return total_loss, sum(losses)
return sum(losses)

View File

@ -78,7 +78,7 @@ def train():
_, output_softmax = model(feat_vec)
# 计算损失
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
# 反向传播
losses.backward()
@ -111,7 +111,7 @@ def train():
print(torch.argmax(output, dim=1)[0])
# 计算损失
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
loss_val += losses.item()
avg_val_loss = loss_val / len(dataloader_val)

305
ginka/train_gan.py Normal file
View File

@ -0,0 +1,305 @@
import argparse
import socket
import struct
import os
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import cv2
import numpy as np
from .model.model import GinkaModel
from .model.loss import GinkaLoss
from .dataset import GinkaDataset, MinamoGANDataset
from minamo.model.model import MinamoModel
from minamo.model.loss import MinamoLoss
from shared.image import matrix_to_image_cv
BATCH_SIZE = 32
EPOCHS_GINKA = 30
EPOCHS_MINAMO = 10
SOCKET_PATH = "./tmp/ginka_uds"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True)
os.makedirs("tmp", exist_ok=True)
def parse_arguments():
parser = argparse.ArgumentParser(description="training codes")
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--from_state", type=str, default="result/ginka.pth")
parser.add_argument("--train", type=str, default="ginka-dataset.json")
parser.add_argument("--validate", type=str, default='ginka-eval.json')
parser.add_argument("--from_cycle", type=int, default=2)
parser.add_argument("--to_cycle", type=int, default=100)
args = parser.parse_args()
return args
def parse_ginka_batch(batch):
target = batch["target"].to(device)
target_vision_feat = batch["target_vision_feat"].to(device)
target_topo_feat = batch["target_topo_feat"].to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
return target, target_vision_feat, target_topo_feat, feat_vec
def parse_minamo_batch(batch):
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
map1 = map1.to(device) # 转为 [B, C, H, W]
map2 = map2.to(device)
topo_simi = topo_simi.to(device)
vision_simi = vision_simi.to(device)
graph1 = graph1.to(device)
graph2 = graph2.to(device)
return map1, map2, vision_simi, topo_simi, graph1, graph2
def send_all(sock, data):
total_sent = 0
while total_sent < len(data):
sent = sock.send(data[total_sent:])
if sent == 0:
raise RuntimeError("Socket connection broken")
total_sent += sent
def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
# 数据通讯 node 输出协议,单位字节:
# 2 - Tensor count; 2 - Review count. Review is right behind train data;
# 1*tc - Compare count for every map tensor delivered.
# 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo;
# N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor.
_, _, H, W = maps.shape
tc_buf = sock.recv(2)
rc_buf = sock.recv(2)
tc = struct.unpack('>h', tc_buf)[0]
rc = struct.unpack('>h', rc_buf)[0]
count_buf = sock.recv(1 * tc)
count: list = struct.unpack(f">{tc}b", count_buf)[0]
N = sum(count)
sim_buf = sock.recv(2 * 4 * (N + rc))
com_buf = sock.recv(N * 1 * H * W)
review_buf = sock.recv(rc * 2 * H * W) if rc > 0 else bytes()
sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)[0]
com = struct.unpack(f">{N * 1 * H * W}b", com_buf)[0]
review = struct.unpack(f">{rc * 2 * H * W}", review_buf)[0] if rc > 0 else list()
res = list()
flatten_idx = 0
# 读取当前这一轮生成器的数据
for idx in range(N):
com_count = count[idx]
for i in range(com_count):
com_start = flatten_idx * H * W
com_end = (flatten_idx + 1) * H * W
vis_sim = sim[flatten_idx * 2]
topo_sim = sim[flatten_idx * 2 + 1]
com_data = com[com_start:com_end]
flatten_idx += 1
com_map = np.fromiter(com_data, np.int8).view(H, W)
# map1, map2, vision_similarity, topo_similarity, is_review
res.append((maps[idx], com_map, vis_sim, topo_sim, False))
return res
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
ginka = GinkaModel()
ginka.to(device)
minamo = MinamoModel(32)
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
minamo.to(device)
minamo.eval()
# 准备数据集
ginka_dataset = GinkaDataset(args.train, device, minamo)
ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
minamo_dataset = MinamoGANDataset()
minamo_dataset_val = MinamoGANDataset()
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True)
minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
# 设定优化器与调度器
optimizer_ginka = optim.AdamW(ginka.parameters(), lr=1e-3)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
criterion_ginka = GinkaLoss(minamo)
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
criterion_minamo = MinamoLoss()
# 用于生成图片
tile_dict = dict()
for file in os.listdir('tiles'):
name = os.path.splitext(file)[0]
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
# 与 node 端通讯
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server.bind(SOCKET_PATH)
server.listen(1)
if args.resume:
data = torch.load(args.from_state, map_location=device)
ginka.load_state_dict(data["model_state"], strict=False)
if args.load_optim:
optimizer_ginka.load_state_dict(data["optimizer_state"])
print("Train from loaded state.")
else:
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
criterion_ginka.weight[0] = 0.0
for cycle in tqdm(range(args.from_cycle, args.to_cycle)):
# -------------------- 训练生成器
gen_list: np.ndarray = np.empty(np.int8)
prob_list: np.ndarray = np.empty(np.float32)
for epoch in tqdm(range(args.epochs), desc="Training Ginka Model"):
ginka.train()
minamo.eval()
total_loss = 0
# 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来
if not args.resume and epoch == 10:
criterion_ginka.weight[0] = 0.5
for batch in ginka_dataloader:
# 数据迁移到设备
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
# 前向传播
optimizer_ginka.zero_grad()
_, output_softmax = ginka(feat_vec)
# 计算损失
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
# 反向传播
losses.backward()
optimizer_ginka.step()
total_loss += losses.item()
avg_loss = total_loss / len(ginka_dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
# 学习率调整
scheduler_ginka.step()
if (epoch + 1) % 5 == 0:
loss_val = 0
ginka.eval()
idx = 0
with torch.no_grad():
for batch in ginka_dataloader_val:
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
output, output_softmax = ginka(feat_vec)
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
loss_val += losses.item()
if epoch + 1 == EPOCHS_GINKA:
# 最后一次验证的时候顺带生成图片
prob = output_softmax.cpu().numpy()
np.concatenate((prob_list, prob), axis=1)
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
gen_list = np.concatenate((gen_list, map_matrix), axis=1)
for matrix in map_matrix:
image = matrix_to_image_cv(matrix, tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
idx += 1
avg_val_loss = loss_val / len(ginka_dataloader_val)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({
"model_state": ginka.state_dict()
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
tqdm.write(f"Cycle {cycle} Ginka train ended.")
torch.save({
"model_state": ginka.state_dict()
}, f"result/ginka.pth")
# -------------------- 生成 Minamo 的训练数据
# 数据通讯 python 输出协议,单位字节:
# 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
N, H, W = gen_list.shape
gen_bytes = gen_list.astype(np.int8).tobytes()
buf = bytearray()
buf.extend(struct.pack('>h', N)) # Tensor count
buf.extend(struct.pack('>b', H)) # Map height
buf.extend(struct.pack('>b', W)) # Map width
buf.extend(gen_bytes) # Map tensor
server.sendall(buf)
data = parse_minamo_data(server, prob_list)
minamo_dataset.set_data(data)
# -------------------- 训练判别器
for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
ginka.eval()
minamo.train()
total_loss = 0
for batch in minamo_dataloader:
map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch)
if map1.shape[0] == 1:
continue
# 前向传播
optimizer_minamo.zero_grad()
vision_feat1, topo_feat1 = minamo(map1, graph1)
vision_feat2, topo_feat2 = minamo(map2, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
# 计算损失
loss = criterion_minamo(vision_pred, topo_pred, vision_simi, topo_simi)
# 反向传播
loss.backward()
optimizer_minamo.step()
total_loss += loss.item()
ave_loss = total_loss / len(minamo_dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
scheduler_minamo.step()
# 每十轮推理一次验证集
if (epoch + 1) % 5 == 0:
minamo.eval()
val_loss = 0
with torch.no_grad():
for val_batch in tqdm(minamo_dataloader_val, leave=False):
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)
vision_feat1, topo_feat1 = minamo(map1_val, graph1)
vision_feat2, topo_feat2 = minamo(map2_val, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
# 计算损失
loss_val = criterion_minamo(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
val_loss += loss_val.item()
avg_val_loss = val_loss / len(minamo_dataloader_val)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({
"model_state": minamo.state_dict()
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
tqdm.write(f"Cycle {cycle} Minamo train ended.")
torch.save({
"model_state": minamo.state_dict()
}, f"result/ginka.pth")
print("Train ended.")
if __name__ == "__main__":
torch.set_num_threads(4)
train()

View File

@ -10,59 +10,11 @@ from minamo.model.model import MinamoModel
from .dataset import GinkaDataset
from .model.loss import GinkaLoss
from .model.model import GinkaModel
from shared.image import matrix_to_image_cv
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs('result/ginka_img', exist_ok=True)
def blend_alpha(bg, fg, alpha):
""" 使用 alpha 通道混合前景图块和背景图 """
for c in range(3): # 只混合 RGB 三个通道
bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c]
return bg
def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
"""
使用OpenCV加速的版本适合大尺寸地图
:param map_matrix: [H, W] 的numpy数组
:param tile_set: 字典 {tile_id: cv2图像BGR格式}
:param tile_size: 图块边长像素
"""
H, W = map_matrix.shape # 获取地图尺寸
canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景)
# 遍历地图矩阵
for row in range(H):
for col in range(W):
tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型
x, y = col * tile_size, row * tile_size # 计算像素位置
# 先绘制地面0
if '0' in tile_set:
canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB
if tile_index == '11':
if row == 0:
tile_index = '11_1'
elif row == W - 1:
tile_index = '11_3'
elif col == 0:
tile_index = '11_2'
elif col == H - 1:
tile_index = '11_4'
# 叠加其他透明图块
if tile_index in tile_set and tile_index != 0:
tile_rgba = tile_set[tile_index]
tile_rgb = tile_rgba[:, :, :3] # 提取 RGB
alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha
# 混合当前图块到背景
canvas[y:y+tile_size, x:x+tile_size] = blend_alpha(
canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha
)
return canvas
def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = GinkaModel()

50
shared/image.py Normal file
View File

@ -0,0 +1,50 @@
import numpy as np
def blend_alpha(bg, fg, alpha):
""" 使用 alpha 通道混合前景图块和背景图 """
for c in range(3): # 只混合 RGB 三个通道
bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c]
return bg
def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
"""
使用OpenCV加速的版本适合大尺寸地图
:param map_matrix: [H, W] 的numpy数组
:param tile_set: 字典 {tile_id: cv2图像BGR格式}
:param tile_size: 图块边长像素
"""
H, W = map_matrix.shape # 获取地图尺寸
canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景)
# 遍历地图矩阵
for row in range(H):
for col in range(W):
tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型
x, y = col * tile_size, row * tile_size # 计算像素位置
# 先绘制地面0
if '0' in tile_set:
canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB
if tile_index == '11':
if row == 0:
tile_index = '11_1'
elif row == W - 1:
tile_index = '11_3'
elif col == 0:
tile_index = '11_2'
elif col == H - 1:
tile_index = '11_4'
# 叠加其他透明图块
if tile_index in tile_set and tile_index != 0:
tile_rgba = tile_set[tile_index]
tile_rgb = tile_rgba[:, :, :3] # 提取 RGB
alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha
# 混合当前图块到背景
canvas[y:y+tile_size, x:x+tile_size] = blend_alpha(
canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha
)
return canvas