diff --git a/data/src/gan.ts b/data/src/gan.ts index 05ab1ca..2c2dc53 100644 --- a/data/src/gan.ts +++ b/data/src/gan.ts @@ -8,7 +8,7 @@ const [refer] = process.argv.slice(2); let id = 0; -function readMap(count: number, buffer: Buffer, h: number, w: number) { +function readMap(count: number, arr: number[], h: number, w: number) { const area = w * h; const maps: number[][][] = Array.from({ @@ -19,7 +19,7 @@ function readMap(count: number, buffer: Buffer, h: number, w: number) { }); }); - buffer.subarray(4).forEach((v, i) => { + arr.forEach((v, i) => { const n = Math.floor(i / area); const y = Math.floor((i % area) / w); const x = i % w; @@ -35,41 +35,80 @@ function generateGANData( map: number[][] ) { const id2 = `$${id++}`; - const toTrain = chooseFrom(keys, 4); + const toTrain = chooseFrom(keys, 30); const data = toTrain.map(v => { const floor = refer.get(v); if (!floor) return []; const size1: [number, number] = [floor.map[0].length, floor.map.length]; const size2: [number, number] = [map[0].length, map.length]; - if (size1[0] !== size2[0] || size1[1] !== size2[1]) return []; + if (size1[0] !== size2[0] || size1[1] !== size2[1]) return []; return generateTrainData(v, id2, floor.map, map, size1); }); return data.flat(); } +const enum ReceiverStatus { + Header, + Content +} + +class DataReceiver { + static active?: DataReceiver + /** 接收状态 */ + private status: ReceiverStatus = ReceiverStatus.Header; + + private received: number[] = [] + private count: number = 0; + private h: number = 0; + private w: number = 0; + + receive(buf: Buffer): [number[][][], number, number, number] | null { + // 数据通讯 node 输入协议,单位字节: + // 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type. + switch (this.status) { + case ReceiverStatus.Header: + this.count = buf.readInt16BE(); + this.h = buf.readInt8(2); + this.w = buf.readInt8(3); + this.received.push(...buf.subarray(4)); + this.status = ReceiverStatus.Content; + break; + case ReceiverStatus.Content: + this.received.push(...buf); + break + } + if (this.received.length === this.count * this.h * this.w) { + delete DataReceiver.active; + return [readMap(this.count, this.received, this.h, this.w), this.count, this.h, this.w]; + } else { + return null; + } + } + + static check(buf: Buffer) { + if (this.active) { + return this.active.receive(buf); + } else { + this.active = new DataReceiver(); + return this.active.receive(buf); + } + } +} + (async () => { const referTower = await readOne(refer); const keys = [...referTower.keys()]; const client = createConnection(SOCKET_FILE, () => { console.log(`UDS IPC connected successfully.`); - // 发送四字节数据表示连接成功 - client.write(new Uint8Array([0x00, 0x00, 0x00, 0x00])); }); client.on('data', buffer => { - // 暂时不考虑流式传输,如果后续数据量非常大,再考虑优化 - // 数据通讯 node 输入协议,单位字节: - // 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type. - const count = buffer.readInt16BE(); - if (buffer.length - 4 !== count * 32 * 32) { - client.write(`ERROR: byte length not match.`); - return []; - } - const h = buffer.readInt8(2); - const w = buffer.readInt8(3); - const map = readMap(count, buffer, h, w); + const data = DataReceiver.check(buffer); + if (!data) return; + + const [map, count, h, w] = data; const simData = map.map(v => generateGANData(keys, referTower, v)); const rc = 0; const compareData = simData.flat(); @@ -83,12 +122,23 @@ function generateGANData( const toSend = Buffer.alloc( 2 + // Tensor count 2 + // Review count - count + // Compare count - 2 * (count + rc) + // Similarity data + 1 * count + // Compare count + 2 * 4 * (compareData.length + rc) + // Similarity data compareData.length * 1 * h * w + // Compare map rc * 2 * h * w, // Review map 0 ); + console.log( + 2, + 2, + count, + 2 * 4 * (compareData.length + rc), + compareData.length * 1 * h * w, + rc * 2 * h * w, + compareData.length, + rc + ); + let offset = 0; toSend.writeInt16BE(count); // Tensor count toSend.writeInt16BE(0, 2); // Review count @@ -98,9 +148,11 @@ function generateGANData( simData.map(v => v.length), offset ); - offset += count; + offset += 1 * count; // Similarity data compareData.forEach(v => { + // console.log(v.visionSimilarity, v.topoSimilarity); + toSend.writeFloatBE(v.visionSimilarity, offset); offset += 4; toSend.writeFloatBE(v.topoSimilarity, offset); @@ -108,15 +160,17 @@ function generateGANData( }); // Compare map toSend.set( - compareData.map(v => v.map1).flat(2), + new Uint8Array(compareData.map(v => v.map1).flat(3)), offset // Set from Compare map ); offset += compareData.length * 1 * h * w; - // Review map - toSend.set( - reviewData.map(v => [v.map1, v.map2]).flat(3), - offset // Set from last chunk - ); + if (reviewData.length > 0) { + // Review map + toSend.set( + new Uint8Array(reviewData.map(v => [v.map1, v.map2]).flat(4)), + offset // Set from last chunk + ); + } client.write(toSend); }); diff --git a/data/src/process/minamo.ts b/data/src/process/minamo.ts index 7edd800..40471b4 100644 --- a/data/src/process/minamo.ts +++ b/data/src/process/minamo.ts @@ -136,7 +136,7 @@ function generateSimilarData(id: string, map: number[][]) { // 生成最多两个微调地图 const width = map[0].length; const height = map.length; - const num = Math.floor(Math.random() * 1); + const num = Math.floor(Math.random() * 2); const res: [id: string, data: MinamoTrainData][] = []; for (let i = 0; i < num; i++) { diff --git a/ginka/dataset.py b/ginka/dataset.py index 2c44aa9..bca3841 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -51,10 +51,14 @@ class GinkaDataset(Dataset): class MinamoGANDataset(Dataset): def __init__(self, refer_data_path): self.refer = load_minamo_gan_data(load_data(refer_data_path)) - self.data = list().extend(self.refer) + self.data = list() + self.data.extend(random.sample(self.refer, 1000)) def set_data(self, data: list): - self.data = data.extend(self.refer) + self.data.clear() + self.data.extend(data) + k = min(len(data) / 4, len(self.refer)) + self.data.extend(random.sample(self.refer, int(k))) def __len__(self): return len(self.data) @@ -63,8 +67,8 @@ class MinamoGANDataset(Dataset): item = self.data[idx] map1, map2, vis_sim, topo_sim, review = item - map1 = torch.ShortTensor(map1) - map2 = torch.ShortTensor(map2) + map1 = torch.LongTensor(map1) + map2 = torch.LongTensor(map2) # 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换 if review: map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W] diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 8d72156..272e6d4 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -255,7 +255,7 @@ class GinkaLoss(nn.Module): minamo_loss = (1.0 - minamo_sim).mean() tqdm.write( - f"{minamo_loss.item():.8f}, {class_loss.item():.8f}, {entrance_loss.item():.8f}, {count_loss.item():.8f}" + f"{minamo_loss.item():.12f}, {class_loss.item():.12f}, {entrance_loss.item():.12f}, {count_loss.item():.12f}" ) losses = [ diff --git a/ginka/train_gan.py b/ginka/train_gan.py index 52ed152..416905d 100644 --- a/ginka/train_gan.py +++ b/ginka/train_gan.py @@ -19,7 +19,7 @@ from shared.image import matrix_to_image_cv BATCH_SIZE = 32 EPOCHS_GINKA = 30 -EPOCHS_MINAMO = 10 +EPOCHS_MINAMO = 15 SOCKET_PATH = "./tmp/ginka_uds" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -64,6 +64,16 @@ def send_all(sock, data): raise RuntimeError("Socket connection broken") total_sent += sent +def recv_all(sock: socket.socket, length: int): + """循环接收直到获得指定长度的数据""" + data = bytes() + while len(data) < length: + packet = sock.recv(length - len(data)) # 只请求剩余部分 + if not packet: + raise ConnectionError("连接中断") + data += packet + return data + def parse_minamo_data(sock: socket.socket, maps: np.ndarray): # 数据通讯 node 输出协议,单位字节: # 2 - Tensor count; 2 - Review count. Review is right behind train data; @@ -75,21 +85,21 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray): rc_buf = sock.recv(2) tc = struct.unpack('>h', tc_buf)[0] rc = struct.unpack('>h', rc_buf)[0] - count_buf = sock.recv(1 * tc) - count: list = struct.unpack(f">{tc}b", count_buf)[0] + count_buf = recv_all(sock, 1 * tc) + count: list = struct.unpack(f">{tc}b", count_buf) N = sum(count) - sim_buf = sock.recv(2 * 4 * (N + rc)) - com_buf = sock.recv(N * 1 * H * W) - review_buf = sock.recv(rc * 2 * H * W) if rc > 0 else bytes() + sim_buf = recv_all(sock, 2 * 4 * (N + rc)) + com_buf = recv_all(sock, N * 1 * H * W) + review_buf = recv_all(sock, rc * 2 * H * W) if rc > 0 else bytes() - sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)[0] - com = struct.unpack(f">{N * 1 * H * W}b", com_buf)[0] - review = struct.unpack(f">{rc * 2 * H * W}", review_buf)[0] if rc > 0 else list() + sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf) + com = struct.unpack(f">{N * 1 * H * W}b", com_buf) + review = struct.unpack(f">{rc * 2 * H * W}", review_buf) if rc > 0 else list() res = list() flatten_idx = 0 # 读取当前这一轮生成器的数据 - for idx in range(N): + for idx in range(tc): com_count = count[idx] for i in range(com_count): com_start = flatten_idx * H * W @@ -98,16 +108,16 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray): topo_sim = sim[flatten_idx * 2 + 1] com_data = com[com_start:com_end] flatten_idx += 1 - com_map = np.fromiter(com_data, np.int8).view(H, W) + com_map = np.array(com_data, dtype=np.int8).reshape(H, W) # map1, map2, vision_similarity, topo_similarity, is_review res.append((maps[idx], com_map, vis_sim, topo_sim, False)) - - return res + + return res def train(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") - args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json') + args = parse_arguments() ginka = GinkaModel() ginka.to(device) @@ -119,8 +129,8 @@ def train(): # 准备数据集 ginka_dataset = GinkaDataset(args.train, device, minamo) ginka_dataset_val = GinkaDataset(args.validate, device, minamo) - minamo_dataset = MinamoGANDataset() - minamo_dataset_val = MinamoGANDataset() + minamo_dataset = MinamoGANDataset("datasets/minamo-dataset-1.json") + minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json") ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True) ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True) minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True) @@ -132,7 +142,7 @@ def train(): criterion_ginka = GinkaLoss(minamo) optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3) - scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6) + scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6) criterion_minamo = MinamoLoss() # 用于生成图片 @@ -142,10 +152,16 @@ def train(): tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) # 与 node 端通讯 + if os.path.exists(SOCKET_PATH): + os.remove(SOCKET_PATH) server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) server.bind(SOCKET_PATH) server.listen(1) + print("Waiting for client connection...") + conn, _ = server.accept() + print("Client connected.") + if args.resume: data = torch.load(args.from_state, map_location=device) ginka.load_state_dict(data["model_state"], strict=False) @@ -157,11 +173,11 @@ def train(): # 从头开始训练的话,初始时先把 minamo 损失值权重改为 0 criterion_ginka.weight[0] = 0.0 - for cycle in tqdm(range(args.from_cycle, args.to_cycle)): + for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"): # -------------------- 训练生成器 - gen_list: np.ndarray = np.empty(np.int8) - prob_list: np.ndarray = np.empty(np.float32) - for epoch in tqdm(range(args.epochs), desc="Training Ginka Model"): + gen_list: np.ndarray = np.empty((0, 13, 13), np.int8) + prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32) + for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model"): ginka.train() minamo.eval() total_loss = 0 @@ -170,7 +186,7 @@ def train(): if not args.resume and epoch == 10: criterion_ginka.weight[0] = 0.5 - for batch in ginka_dataloader: + for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"): # 数据迁移到设备 target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) # 前向传播 @@ -187,14 +203,14 @@ def train(): tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}") # 学习率调整 - scheduler_ginka.step() + scheduler_ginka.step(epoch + 1) if (epoch + 1) % 5 == 0: loss_val = 0 ginka.eval() idx = 0 with torch.no_grad(): - for batch in ginka_dataloader_val: + for batch in tqdm(ginka_dataloader_val, leave=False, desc="Validating Ginka Model"): target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) output, output_softmax = ginka(feat_vec) losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat) @@ -202,9 +218,9 @@ def train(): if epoch + 1 == EPOCHS_GINKA: # 最后一次验证的时候顺带生成图片 prob = output_softmax.cpu().numpy() - np.concatenate((prob_list, prob), axis=1) + prob_list = np.concatenate((prob_list, prob), axis=0) map_matrix = torch.argmax(output, dim=1).cpu().numpy() - gen_list = np.concatenate((gen_list, map_matrix), axis=1) + gen_list = np.concatenate((gen_list, map_matrix), axis=0) for matrix in map_matrix: image = matrix_to_image_cv(matrix, tile_dict) cv2.imwrite(f"result/ginka_img/{idx}.png", image) @@ -232,8 +248,8 @@ def train(): buf.extend(struct.pack('>b', H)) # Map height buf.extend(struct.pack('>b', W)) # Map width buf.extend(gen_bytes) # Map tensor - server.sendall(buf) - data = parse_minamo_data(server, prob_list) + conn.sendall(buf) + data = parse_minamo_data(conn, prob_list) minamo_dataset.set_data(data) # -------------------- 训练判别器 @@ -242,7 +258,7 @@ def train(): minamo.train() total_loss = 0 - for batch in minamo_dataloader: + for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"): map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch) if map1.shape[0] == 1: @@ -265,16 +281,16 @@ def train(): total_loss += loss.item() ave_loss = total_loss / len(minamo_dataloader) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}") + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_minamo.param_groups[0]['lr']):.6f}") - scheduler_minamo.step() + scheduler_minamo.step(epoch + 1) # 每十轮推理一次验证集 if (epoch + 1) % 5 == 0: minamo.eval() val_loss = 0 with torch.no_grad(): - for val_batch in tqdm(minamo_dataloader_val, leave=False): + for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"): map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch) vision_feat1, topo_feat1 = minamo(map1_val, graph1) diff --git a/minamo/model/loss.py b/minamo/model/loss.py index 6fb1719..ffcf575 100644 --- a/minamo/model/loss.py +++ b/minamo/model/loss.py @@ -1,4 +1,5 @@ import torch.nn as nn +from tqdm import tqdm class MinamoLoss(nn.Module): def __init__(self, vision_weight=0.2, topo_weight=0.8): @@ -9,7 +10,7 @@ class MinamoLoss(nn.Module): def forward(self, vis_pred, topo_pred, vis_true, topo_true): # print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape) - # print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item()) + # tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f}") vis_loss = self.loss(vis_pred, vis_true) topo_loss = self.loss(topo_pred, topo_true) # print(vis_loss.item(), topo_loss.item())