feat: 支持流式传输,改进数据集

This commit is contained in:
unanmed 2025-04-02 17:57:02 +08:00
parent 8dea79a9f0
commit 325eb599c3
6 changed files with 140 additions and 65 deletions

View File

@ -8,7 +8,7 @@ const [refer] = process.argv.slice(2);
let id = 0;
function readMap(count: number, buffer: Buffer, h: number, w: number) {
function readMap(count: number, arr: number[], h: number, w: number) {
const area = w * h;
const maps: number[][][] = Array.from<number[][]>({
@ -19,7 +19,7 @@ function readMap(count: number, buffer: Buffer, h: number, w: number) {
});
});
buffer.subarray(4).forEach((v, i) => {
arr.forEach((v, i) => {
const n = Math.floor(i / area);
const y = Math.floor((i % area) / w);
const x = i % w;
@ -35,7 +35,7 @@ function generateGANData(
map: number[][]
) {
const id2 = `$${id++}`;
const toTrain = chooseFrom(keys, 4);
const toTrain = chooseFrom(keys, 30);
const data = toTrain.map<MinamoTrainData[]>(v => {
const floor = refer.get(v);
if (!floor) return [];
@ -48,28 +48,67 @@ function generateGANData(
return data.flat();
}
/**
 * Phase of the chunked receive performed by DataReceiver.
 * A receiver starts in `Header` and switches to `Content` once the first
 * chunk has been processed.
 */
const enum ReceiverStatus {
/** Next chunk starts with the 4-byte header (tensor count, map height, map width). */
Header,
/** Header already parsed; every further chunk is treated as raw map bytes. */
Content
}
/**
 * Reassembles one protocol message from the chunks delivered by the socket's
 * `data` event.
 *
 * Input protocol (bytes):
 *   2 - tensor count (int16, big-endian); 1 - map height; 1 - map width;
 *   N*1*H*W - map tensor, one byte per tile.
 *
 * A stream socket gives no guarantee about chunk boundaries, so a chunk may
 * hold less than the header, exactly one message, or bytes that already
 * belong to the following message. All three cases are handled here.
 */
class DataReceiver {
    /** Receiver of the message currently in flight, if any. */
    static active?: DataReceiver;

    /** Size in bytes of the fixed header: 2 (count) + 1 (height) + 1 (width). */
    private static readonly HEADER_SIZE = 4;

    /** Current receive state. */
    private status: ReceiverStatus = ReceiverStatus.Header;
    /** Map bytes collected so far (header excluded). */
    private received: number[] = [];
    /** Bytes held back until the next chunk: a partial header, or the start of this message. */
    private carry: Buffer = Buffer.alloc(0);
    private count: number = 0;
    private h: number = 0;
    private w: number = 0;
    /** Number of map bytes announced by the header (`count * h * w`). */
    private expected: number = 0;

    /**
     * Feeds one chunk into the receiver.
     *
     * @param buf Chunk received from the socket.
     * @returns `[maps, count, height, width]` once the announced number of
     *   map bytes has arrived, otherwise `null`. Bytes following the end of
     *   the message are kept and become the beginning of the next message.
     * @throws RangeError If the header announces a negative size, which can
     *   only happen when the stream is corrupted or out of sync.
     */
    receive(buf: Buffer): [number[][][], number, number, number] | null {
        // Prepend whatever was held back by a previous call.
        let chunk: Buffer =
            this.carry.length > 0 ? Buffer.concat([this.carry, buf]) : buf;
        this.carry = Buffer.alloc(0);

        if (this.status === ReceiverStatus.Header) {
            if (chunk.length < DataReceiver.HEADER_SIZE) {
                // Header is split across chunks: wait for the rest of it.
                this.carry = chunk;
                return null;
            }
            this.count = chunk.readInt16BE(0);
            this.h = chunk.readInt8(2);
            this.w = chunk.readInt8(3);
            if (this.count < 0 || this.h < 0 || this.w < 0) {
                // Drop the broken receiver so that a later message can start clean.
                delete DataReceiver.active;
                throw new RangeError(
                    `Invalid header: count=${this.count}, h=${this.h}, w=${this.w}`
                );
            }
            this.expected = this.count * this.h * this.w;
            chunk = chunk.subarray(DataReceiver.HEADER_SIZE);
            this.status = ReceiverStatus.Content;
        }

        // Take only the bytes that belong to this message. A plain loop is
        // used instead of `push(...chunk)` because spreading a large chunk
        // into call arguments can exceed the engine's argument limit.
        const take = Math.min(chunk.length, this.expected - this.received.length);
        for (const byte of chunk.subarray(0, take)) {
            this.received.push(byte);
        }
        if (this.received.length < this.expected) return null;

        const rest = chunk.subarray(take);
        if (rest.length > 0) {
            // Surplus bytes are the start of the next message; it is parsed
            // when the next chunk arrives.
            const next = new DataReceiver();
            next.carry = rest;
            DataReceiver.active = next;
        } else {
            delete DataReceiver.active;
        }
        return [
            readMap(this.count, this.received, this.h, this.w),
            this.count,
            this.h,
            this.w
        ];
    }

    /**
     * Routes a chunk to the receiver of the message in flight, creating a new
     * receiver when no message is being received.
     *
     * @param buf Chunk received from the socket.
     * @returns Same as {@link DataReceiver.receive}.
     */
    static check(buf: Buffer): [number[][][], number, number, number] | null {
        const receiver = DataReceiver.active ?? new DataReceiver();
        DataReceiver.active = receiver;
        return receiver.receive(buf);
    }
}
(async () => {
const referTower = await readOne(refer);
const keys = [...referTower.keys()];
const client = createConnection(SOCKET_FILE, () => {
console.log(`UDS IPC connected successfully.`);
// 发送四字节数据表示连接成功
client.write(new Uint8Array([0x00, 0x00, 0x00, 0x00]));
});
client.on('data', buffer => {
// 暂时不考虑流式传输,如果后续数据量非常大,再考虑优化
// 数据通讯 node 输入协议,单位字节:
// 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
const count = buffer.readInt16BE();
if (buffer.length - 4 !== count * 32 * 32) {
client.write(`ERROR: byte length not match.`);
return [];
}
const h = buffer.readInt8(2);
const w = buffer.readInt8(3);
const map = readMap(count, buffer, h, w);
const data = DataReceiver.check(buffer);
if (!data) return;
const [map, count, h, w] = data;
const simData = map.map(v => generateGANData(keys, referTower, v));
const rc = 0;
const compareData = simData.flat();
@ -83,12 +122,23 @@ function generateGANData(
const toSend = Buffer.alloc(
2 + // Tensor count
2 + // Review count
count + // Compare count
2 * (count + rc) + // Similarity data
1 * count + // Compare count
2 * 4 * (compareData.length + rc) + // Similarity data
compareData.length * 1 * h * w + // Compare map
rc * 2 * h * w, // Review map
0
);
console.log(
2,
2,
count,
2 * 4 * (compareData.length + rc),
compareData.length * 1 * h * w,
rc * 2 * h * w,
compareData.length,
rc
);
let offset = 0;
toSend.writeInt16BE(count); // Tensor count
toSend.writeInt16BE(0, 2); // Review count
@ -98,9 +148,11 @@ function generateGANData(
simData.map(v => v.length),
offset
);
offset += count;
offset += 1 * count;
// Similarity data
compareData.forEach(v => {
// console.log(v.visionSimilarity, v.topoSimilarity);
toSend.writeFloatBE(v.visionSimilarity, offset);
offset += 4;
toSend.writeFloatBE(v.topoSimilarity, offset);
@ -108,15 +160,17 @@ function generateGANData(
});
// Compare map
toSend.set(
compareData.map(v => v.map1).flat(2),
new Uint8Array(compareData.map(v => v.map1).flat(3)),
offset // Set from Compare map
);
offset += compareData.length * 1 * h * w;
// Review map
toSend.set(
reviewData.map(v => [v.map1, v.map2]).flat(3),
offset // Set from last chunk
);
if (reviewData.length > 0) {
// Review map
toSend.set(
new Uint8Array(reviewData.map(v => [v.map1, v.map2]).flat(4)),
offset // Set from last chunk
);
}
client.write(toSend);
});

View File

@ -136,7 +136,7 @@ function generateSimilarData(id: string, map: number[][]) {
// 生成最多两个微调地图
const width = map[0].length;
const height = map.length;
const num = Math.floor(Math.random() * 1);
const num = Math.floor(Math.random() * 2);
const res: [id: string, data: MinamoTrainData][] = [];
for (let i = 0; i < num; i++) {

View File

@ -51,10 +51,14 @@ class GinkaDataset(Dataset):
class MinamoGANDataset(Dataset):
def __init__(self, refer_data_path):
self.refer = load_minamo_gan_data(load_data(refer_data_path))
self.data = list().extend(self.refer)
self.data = list()
self.data.extend(random.sample(self.refer, 1000))
def set_data(self, data: list):
self.data = data.extend(self.refer)
self.data.clear()
self.data.extend(data)
k = min(len(data) / 4, len(self.refer))
self.data.extend(random.sample(self.refer, int(k)))
def __len__(self):
return len(self.data)
@ -63,8 +67,8 @@ class MinamoGANDataset(Dataset):
item = self.data[idx]
map1, map2, vis_sim, topo_sim, review = item
map1 = torch.ShortTensor(map1)
map2 = torch.ShortTensor(map2)
map1 = torch.LongTensor(map1)
map2 = torch.LongTensor(map2)
# 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换
if review:
map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W]

View File

@ -255,7 +255,7 @@ class GinkaLoss(nn.Module):
minamo_loss = (1.0 - minamo_sim).mean()
tqdm.write(
f"{minamo_loss.item():.8f}, {class_loss.item():.8f}, {entrance_loss.item():.8f}, {count_loss.item():.8f}"
f"{minamo_loss.item():.12f}, {class_loss.item():.12f}, {entrance_loss.item():.12f}, {count_loss.item():.12f}"
)
losses = [

View File

@ -19,7 +19,7 @@ from shared.image import matrix_to_image_cv
BATCH_SIZE = 32
EPOCHS_GINKA = 30
EPOCHS_MINAMO = 10
EPOCHS_MINAMO = 15
SOCKET_PATH = "./tmp/ginka_uds"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -64,6 +64,16 @@ def send_all(sock, data):
raise RuntimeError("Socket connection broken")
total_sent += sent
def recv_all(sock: socket.socket, length: int):
    """Receive exactly ``length`` bytes from ``sock``.

    ``sock.recv`` may return fewer bytes than requested, so it is called
    repeatedly, each time asking only for the part that is still missing.

    Args:
        sock: Object providing ``recv(n)``.
        length: Number of bytes to read. For ``length <= 0`` nothing is
            read and an empty ``bytes`` object is returned.

    Returns:
        A ``bytes`` object of exactly ``length`` bytes.

    Raises:
        ConnectionError: If ``recv`` returns an empty result before
            ``length`` bytes have been received.
    """
    # Accumulate in a bytearray: appending to immutable bytes would copy the
    # whole payload again on every short read.
    data = bytearray()
    while len(data) < length:
        packet = sock.recv(length - len(data))  # request only the remainder
        if not packet:
            raise ConnectionError("连接中断")
        data += packet
    return bytes(data)
def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
# 数据通讯 node 输出协议,单位字节:
# 2 - Tensor count; 2 - Review count. Review is right behind train data;
@ -75,21 +85,21 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
rc_buf = sock.recv(2)
tc = struct.unpack('>h', tc_buf)[0]
rc = struct.unpack('>h', rc_buf)[0]
count_buf = sock.recv(1 * tc)
count: list = struct.unpack(f">{tc}b", count_buf)[0]
count_buf = recv_all(sock, 1 * tc)
count: list = struct.unpack(f">{tc}b", count_buf)
N = sum(count)
sim_buf = sock.recv(2 * 4 * (N + rc))
com_buf = sock.recv(N * 1 * H * W)
review_buf = sock.recv(rc * 2 * H * W) if rc > 0 else bytes()
sim_buf = recv_all(sock, 2 * 4 * (N + rc))
com_buf = recv_all(sock, N * 1 * H * W)
review_buf = recv_all(sock, rc * 2 * H * W) if rc > 0 else bytes()
sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)[0]
com = struct.unpack(f">{N * 1 * H * W}b", com_buf)[0]
review = struct.unpack(f">{rc * 2 * H * W}", review_buf)[0] if rc > 0 else list()
sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)
com = struct.unpack(f">{N * 1 * H * W}b", com_buf)
review = struct.unpack(f">{rc * 2 * H * W}", review_buf) if rc > 0 else list()
res = list()
flatten_idx = 0
# 读取当前这一轮生成器的数据
for idx in range(N):
for idx in range(tc):
com_count = count[idx]
for i in range(com_count):
com_start = flatten_idx * H * W
@ -98,7 +108,7 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
topo_sim = sim[flatten_idx * 2 + 1]
com_data = com[com_start:com_end]
flatten_idx += 1
com_map = np.fromiter(com_data, np.int8).view(H, W)
com_map = np.array(com_data, dtype=np.int8).reshape(H, W)
# map1, map2, vision_similarity, topo_similarity, is_review
res.append((maps[idx], com_map, vis_sim, topo_sim, False))
@ -107,7 +117,7 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
args = parse_arguments()
ginka = GinkaModel()
ginka.to(device)
@ -119,8 +129,8 @@ def train():
# 准备数据集
ginka_dataset = GinkaDataset(args.train, device, minamo)
ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
minamo_dataset = MinamoGANDataset()
minamo_dataset_val = MinamoGANDataset()
minamo_dataset = MinamoGANDataset("datasets/minamo-dataset-1.json")
minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json")
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True)
@ -132,7 +142,7 @@ def train():
criterion_ginka = GinkaLoss(minamo)
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
criterion_minamo = MinamoLoss()
# 用于生成图片
@ -142,10 +152,16 @@ def train():
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
# 与 node 端通讯
if os.path.exists(SOCKET_PATH):
os.remove(SOCKET_PATH)
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server.bind(SOCKET_PATH)
server.listen(1)
print("Waiting for client connection...")
conn, _ = server.accept()
print("Client connected.")
if args.resume:
data = torch.load(args.from_state, map_location=device)
ginka.load_state_dict(data["model_state"], strict=False)
@ -157,11 +173,11 @@ def train():
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
criterion_ginka.weight[0] = 0.0
for cycle in tqdm(range(args.from_cycle, args.to_cycle)):
for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
# -------------------- 训练生成器
gen_list: np.ndarray = np.empty(np.int8)
prob_list: np.ndarray = np.empty(np.float32)
for epoch in tqdm(range(args.epochs), desc="Training Ginka Model"):
gen_list: np.ndarray = np.empty((0, 13, 13), np.int8)
prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32)
for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model"):
ginka.train()
minamo.eval()
total_loss = 0
@ -170,7 +186,7 @@ def train():
if not args.resume and epoch == 10:
criterion_ginka.weight[0] = 0.5
for batch in ginka_dataloader:
for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"):
# 数据迁移到设备
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
# 前向传播
@ -187,14 +203,14 @@ def train():
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
# 学习率调整
scheduler_ginka.step()
scheduler_ginka.step(epoch + 1)
if (epoch + 1) % 5 == 0:
loss_val = 0
ginka.eval()
idx = 0
with torch.no_grad():
for batch in ginka_dataloader_val:
for batch in tqdm(ginka_dataloader_val, leave=False, desc="Validating Ginka Model"):
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
output, output_softmax = ginka(feat_vec)
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
@ -202,9 +218,9 @@ def train():
if epoch + 1 == EPOCHS_GINKA:
# 最后一次验证的时候顺带生成图片
prob = output_softmax.cpu().numpy()
np.concatenate((prob_list, prob), axis=1)
prob_list = np.concatenate((prob_list, prob), axis=0)
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
gen_list = np.concatenate((gen_list, map_matrix), axis=1)
gen_list = np.concatenate((gen_list, map_matrix), axis=0)
for matrix in map_matrix:
image = matrix_to_image_cv(matrix, tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
@ -232,8 +248,8 @@ def train():
buf.extend(struct.pack('>b', H)) # Map height
buf.extend(struct.pack('>b', W)) # Map width
buf.extend(gen_bytes) # Map tensor
server.sendall(buf)
data = parse_minamo_data(server, prob_list)
conn.sendall(buf)
data = parse_minamo_data(conn, prob_list)
minamo_dataset.set_data(data)
# -------------------- 训练判别器
@ -242,7 +258,7 @@ def train():
minamo.train()
total_loss = 0
for batch in minamo_dataloader:
for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"):
map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch)
if map1.shape[0] == 1:
@ -265,16 +281,16 @@ def train():
total_loss += loss.item()
ave_loss = total_loss / len(minamo_dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_minamo.param_groups[0]['lr']):.6f}")
scheduler_minamo.step()
scheduler_minamo.step(epoch + 1)
# 每十轮推理一次验证集
if (epoch + 1) % 5 == 0:
minamo.eval()
val_loss = 0
with torch.no_grad():
for val_batch in tqdm(minamo_dataloader_val, leave=False):
for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"):
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)
vision_feat1, topo_feat1 = minamo(map1_val, graph1)

View File

@ -1,4 +1,5 @@
import torch.nn as nn
from tqdm import tqdm
class MinamoLoss(nn.Module):
def __init__(self, vision_weight=0.2, topo_weight=0.8):
@ -9,7 +10,7 @@ class MinamoLoss(nn.Module):
def forward(self, vis_pred, topo_pred, vis_true, topo_true):
# print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
# print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item())
# tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f}")
vis_loss = self.loss(vis_pred, vis_true)
topo_loss = self.loss(topo_pred, topo_true)
# print(vis_loss.item(), topo_loss.item())