feat: 支持流式传输,改进数据集

This commit is contained in:
unanmed 2025-04-02 17:57:02 +08:00
parent 8dea79a9f0
commit 325eb599c3
6 changed files with 140 additions and 65 deletions

View File

@ -8,7 +8,7 @@ const [refer] = process.argv.slice(2);
let id = 0; let id = 0;
function readMap(count: number, buffer: Buffer, h: number, w: number) { function readMap(count: number, arr: number[], h: number, w: number) {
const area = w * h; const area = w * h;
const maps: number[][][] = Array.from<number[][]>({ const maps: number[][][] = Array.from<number[][]>({
@ -19,7 +19,7 @@ function readMap(count: number, buffer: Buffer, h: number, w: number) {
}); });
}); });
buffer.subarray(4).forEach((v, i) => { arr.forEach((v, i) => {
const n = Math.floor(i / area); const n = Math.floor(i / area);
const y = Math.floor((i % area) / w); const y = Math.floor((i % area) / w);
const x = i % w; const x = i % w;
@ -35,41 +35,80 @@ function generateGANData(
map: number[][] map: number[][]
) { ) {
const id2 = `$${id++}`; const id2 = `$${id++}`;
const toTrain = chooseFrom(keys, 4); const toTrain = chooseFrom(keys, 30);
const data = toTrain.map<MinamoTrainData[]>(v => { const data = toTrain.map<MinamoTrainData[]>(v => {
const floor = refer.get(v); const floor = refer.get(v);
if (!floor) return []; if (!floor) return [];
const size1: [number, number] = [floor.map[0].length, floor.map.length]; const size1: [number, number] = [floor.map[0].length, floor.map.length];
const size2: [number, number] = [map[0].length, map.length]; const size2: [number, number] = [map[0].length, map.length];
if (size1[0] !== size2[0] || size1[1] !== size2[1]) return []; if (size1[0] !== size2[0] || size1[1] !== size2[1]) return [];
return generateTrainData(v, id2, floor.map, map, size1); return generateTrainData(v, id2, floor.map, map, size1);
}); });
return data.flat(); return data.flat();
} }
const enum ReceiverStatus {
Header,
Content
}
class DataReceiver {
static active?: DataReceiver
/** 接收状态 */
private status: ReceiverStatus = ReceiverStatus.Header;
private received: number[] = []
private count: number = 0;
private h: number = 0;
private w: number = 0;
receive(buf: Buffer): [number[][][], number, number, number] | null {
// 数据通讯 node 输入协议,单位字节:
// 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
switch (this.status) {
case ReceiverStatus.Header:
this.count = buf.readInt16BE();
this.h = buf.readInt8(2);
this.w = buf.readInt8(3);
this.received.push(...buf.subarray(4));
this.status = ReceiverStatus.Content;
break;
case ReceiverStatus.Content:
this.received.push(...buf);
break
}
if (this.received.length === this.count * this.h * this.w) {
delete DataReceiver.active;
return [readMap(this.count, this.received, this.h, this.w), this.count, this.h, this.w];
} else {
return null;
}
}
static check(buf: Buffer) {
if (this.active) {
return this.active.receive(buf);
} else {
this.active = new DataReceiver();
return this.active.receive(buf);
}
}
}
(async () => { (async () => {
const referTower = await readOne(refer); const referTower = await readOne(refer);
const keys = [...referTower.keys()]; const keys = [...referTower.keys()];
const client = createConnection(SOCKET_FILE, () => { const client = createConnection(SOCKET_FILE, () => {
console.log(`UDS IPC connected successfully.`); console.log(`UDS IPC connected successfully.`);
// 发送四字节数据表示连接成功
client.write(new Uint8Array([0x00, 0x00, 0x00, 0x00]));
}); });
client.on('data', buffer => { client.on('data', buffer => {
// 暂时不考虑流式传输,如果后续数据量非常大,再考虑优化 const data = DataReceiver.check(buffer);
// 数据通讯 node 输入协议,单位字节: if (!data) return;
// 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
const count = buffer.readInt16BE(); const [map, count, h, w] = data;
if (buffer.length - 4 !== count * 32 * 32) {
client.write(`ERROR: byte length not match.`);
return [];
}
const h = buffer.readInt8(2);
const w = buffer.readInt8(3);
const map = readMap(count, buffer, h, w);
const simData = map.map(v => generateGANData(keys, referTower, v)); const simData = map.map(v => generateGANData(keys, referTower, v));
const rc = 0; const rc = 0;
const compareData = simData.flat(); const compareData = simData.flat();
@ -83,12 +122,23 @@ function generateGANData(
const toSend = Buffer.alloc( const toSend = Buffer.alloc(
2 + // Tensor count 2 + // Tensor count
2 + // Review count 2 + // Review count
count + // Compare count 1 * count + // Compare count
2 * (count + rc) + // Similarity data 2 * 4 * (compareData.length + rc) + // Similarity data
compareData.length * 1 * h * w + // Compare map compareData.length * 1 * h * w + // Compare map
rc * 2 * h * w, // Review map rc * 2 * h * w, // Review map
0 0
); );
console.log(
2,
2,
count,
2 * 4 * (compareData.length + rc),
compareData.length * 1 * h * w,
rc * 2 * h * w,
compareData.length,
rc
);
let offset = 0; let offset = 0;
toSend.writeInt16BE(count); // Tensor count toSend.writeInt16BE(count); // Tensor count
toSend.writeInt16BE(0, 2); // Review count toSend.writeInt16BE(0, 2); // Review count
@ -98,9 +148,11 @@ function generateGANData(
simData.map(v => v.length), simData.map(v => v.length),
offset offset
); );
offset += count; offset += 1 * count;
// Similarity data // Similarity data
compareData.forEach(v => { compareData.forEach(v => {
// console.log(v.visionSimilarity, v.topoSimilarity);
toSend.writeFloatBE(v.visionSimilarity, offset); toSend.writeFloatBE(v.visionSimilarity, offset);
offset += 4; offset += 4;
toSend.writeFloatBE(v.topoSimilarity, offset); toSend.writeFloatBE(v.topoSimilarity, offset);
@ -108,15 +160,17 @@ function generateGANData(
}); });
// Compare map // Compare map
toSend.set( toSend.set(
compareData.map(v => v.map1).flat(2), new Uint8Array(compareData.map(v => v.map1).flat(3)),
offset // Set from Compare map offset // Set from Compare map
); );
offset += compareData.length * 1 * h * w; offset += compareData.length * 1 * h * w;
// Review map if (reviewData.length > 0) {
toSend.set( // Review map
reviewData.map(v => [v.map1, v.map2]).flat(3), toSend.set(
offset // Set from last chunk new Uint8Array(reviewData.map(v => [v.map1, v.map2]).flat(4)),
); offset // Set from last chunk
);
}
client.write(toSend); client.write(toSend);
}); });

View File

@ -136,7 +136,7 @@ function generateSimilarData(id: string, map: number[][]) {
// 生成最多两个微调地图 // 生成最多两个微调地图
const width = map[0].length; const width = map[0].length;
const height = map.length; const height = map.length;
const num = Math.floor(Math.random() * 1); const num = Math.floor(Math.random() * 2);
const res: [id: string, data: MinamoTrainData][] = []; const res: [id: string, data: MinamoTrainData][] = [];
for (let i = 0; i < num; i++) { for (let i = 0; i < num; i++) {

View File

@ -51,10 +51,14 @@ class GinkaDataset(Dataset):
class MinamoGANDataset(Dataset): class MinamoGANDataset(Dataset):
def __init__(self, refer_data_path): def __init__(self, refer_data_path):
self.refer = load_minamo_gan_data(load_data(refer_data_path)) self.refer = load_minamo_gan_data(load_data(refer_data_path))
self.data = list().extend(self.refer) self.data = list()
self.data.extend(random.sample(self.refer, 1000))
def set_data(self, data: list): def set_data(self, data: list):
self.data = data.extend(self.refer) self.data.clear()
self.data.extend(data)
k = min(len(data) / 4, len(self.refer))
self.data.extend(random.sample(self.refer, int(k)))
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
@ -63,8 +67,8 @@ class MinamoGANDataset(Dataset):
item = self.data[idx] item = self.data[idx]
map1, map2, vis_sim, topo_sim, review = item map1, map2, vis_sim, topo_sim, review = item
map1 = torch.ShortTensor(map1) map1 = torch.LongTensor(map1)
map2 = torch.ShortTensor(map2) map2 = torch.LongTensor(map2)
# 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换 # 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换
if review: if review:
map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W] map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W]

View File

@ -255,7 +255,7 @@ class GinkaLoss(nn.Module):
minamo_loss = (1.0 - minamo_sim).mean() minamo_loss = (1.0 - minamo_sim).mean()
tqdm.write( tqdm.write(
f"{minamo_loss.item():.8f}, {class_loss.item():.8f}, {entrance_loss.item():.8f}, {count_loss.item():.8f}" f"{minamo_loss.item():.12f}, {class_loss.item():.12f}, {entrance_loss.item():.12f}, {count_loss.item():.12f}"
) )
losses = [ losses = [

View File

@ -19,7 +19,7 @@ from shared.image import matrix_to_image_cv
BATCH_SIZE = 32 BATCH_SIZE = 32
EPOCHS_GINKA = 30 EPOCHS_GINKA = 30
EPOCHS_MINAMO = 10 EPOCHS_MINAMO = 15
SOCKET_PATH = "./tmp/ginka_uds" SOCKET_PATH = "./tmp/ginka_uds"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -64,6 +64,16 @@ def send_all(sock, data):
raise RuntimeError("Socket connection broken") raise RuntimeError("Socket connection broken")
total_sent += sent total_sent += sent
def recv_all(sock: socket.socket, length: int):
"""循环接收直到获得指定长度的数据"""
data = bytes()
while len(data) < length:
packet = sock.recv(length - len(data)) # 只请求剩余部分
if not packet:
raise ConnectionError("连接中断")
data += packet
return data
def parse_minamo_data(sock: socket.socket, maps: np.ndarray): def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
# 数据通讯 node 输出协议,单位字节: # 数据通讯 node 输出协议,单位字节:
# 2 - Tensor count; 2 - Review count. Review is right behind train data; # 2 - Tensor count; 2 - Review count. Review is right behind train data;
@ -75,21 +85,21 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
rc_buf = sock.recv(2) rc_buf = sock.recv(2)
tc = struct.unpack('>h', tc_buf)[0] tc = struct.unpack('>h', tc_buf)[0]
rc = struct.unpack('>h', rc_buf)[0] rc = struct.unpack('>h', rc_buf)[0]
count_buf = sock.recv(1 * tc) count_buf = recv_all(sock, 1 * tc)
count: list = struct.unpack(f">{tc}b", count_buf)[0] count: list = struct.unpack(f">{tc}b", count_buf)
N = sum(count) N = sum(count)
sim_buf = sock.recv(2 * 4 * (N + rc)) sim_buf = recv_all(sock, 2 * 4 * (N + rc))
com_buf = sock.recv(N * 1 * H * W) com_buf = recv_all(sock, N * 1 * H * W)
review_buf = sock.recv(rc * 2 * H * W) if rc > 0 else bytes() review_buf = recv_all(sock, rc * 2 * H * W) if rc > 0 else bytes()
sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)[0] sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)
com = struct.unpack(f">{N * 1 * H * W}b", com_buf)[0] com = struct.unpack(f">{N * 1 * H * W}b", com_buf)
review = struct.unpack(f">{rc * 2 * H * W}", review_buf)[0] if rc > 0 else list() review = struct.unpack(f">{rc * 2 * H * W}", review_buf) if rc > 0 else list()
res = list() res = list()
flatten_idx = 0 flatten_idx = 0
# 读取当前这一轮生成器的数据 # 读取当前这一轮生成器的数据
for idx in range(N): for idx in range(tc):
com_count = count[idx] com_count = count[idx]
for i in range(com_count): for i in range(com_count):
com_start = flatten_idx * H * W com_start = flatten_idx * H * W
@ -98,16 +108,16 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
topo_sim = sim[flatten_idx * 2 + 1] topo_sim = sim[flatten_idx * 2 + 1]
com_data = com[com_start:com_end] com_data = com[com_start:com_end]
flatten_idx += 1 flatten_idx += 1
com_map = np.fromiter(com_data, np.int8).view(H, W) com_map = np.array(com_data, dtype=np.int8).reshape(H, W)
# map1, map2, vision_similarity, topo_similarity, is_review # map1, map2, vision_similarity, topo_similarity, is_review
res.append((maps[idx], com_map, vis_sim, topo_sim, False)) res.append((maps[idx], com_map, vis_sim, topo_sim, False))
return res return res
def train(): def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json') args = parse_arguments()
ginka = GinkaModel() ginka = GinkaModel()
ginka.to(device) ginka.to(device)
@ -119,8 +129,8 @@ def train():
# 准备数据集 # 准备数据集
ginka_dataset = GinkaDataset(args.train, device, minamo) ginka_dataset = GinkaDataset(args.train, device, minamo)
ginka_dataset_val = GinkaDataset(args.validate, device, minamo) ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
minamo_dataset = MinamoGANDataset() minamo_dataset = MinamoGANDataset("datasets/minamo-dataset-1.json")
minamo_dataset_val = MinamoGANDataset() minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json")
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True) ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True) ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True) minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True)
@ -132,7 +142,7 @@ def train():
criterion_ginka = GinkaLoss(minamo) criterion_ginka = GinkaLoss(minamo)
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3) optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6) scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
criterion_minamo = MinamoLoss() criterion_minamo = MinamoLoss()
# 用于生成图片 # 用于生成图片
@ -142,10 +152,16 @@ def train():
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
# 与 node 端通讯 # 与 node 端通讯
if os.path.exists(SOCKET_PATH):
os.remove(SOCKET_PATH)
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server.bind(SOCKET_PATH) server.bind(SOCKET_PATH)
server.listen(1) server.listen(1)
print("Waiting for client connection...")
conn, _ = server.accept()
print("Client connected.")
if args.resume: if args.resume:
data = torch.load(args.from_state, map_location=device) data = torch.load(args.from_state, map_location=device)
ginka.load_state_dict(data["model_state"], strict=False) ginka.load_state_dict(data["model_state"], strict=False)
@ -157,11 +173,11 @@ def train():
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0 # 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
criterion_ginka.weight[0] = 0.0 criterion_ginka.weight[0] = 0.0
for cycle in tqdm(range(args.from_cycle, args.to_cycle)): for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
# -------------------- 训练生成器 # -------------------- 训练生成器
gen_list: np.ndarray = np.empty(np.int8) gen_list: np.ndarray = np.empty((0, 13, 13), np.int8)
prob_list: np.ndarray = np.empty(np.float32) prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32)
for epoch in tqdm(range(args.epochs), desc="Training Ginka Model"): for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model"):
ginka.train() ginka.train()
minamo.eval() minamo.eval()
total_loss = 0 total_loss = 0
@ -170,7 +186,7 @@ def train():
if not args.resume and epoch == 10: if not args.resume and epoch == 10:
criterion_ginka.weight[0] = 0.5 criterion_ginka.weight[0] = 0.5
for batch in ginka_dataloader: for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"):
# 数据迁移到设备 # 数据迁移到设备
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
# 前向传播 # 前向传播
@ -187,14 +203,14 @@ def train():
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}") tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
# 学习率调整 # 学习率调整
scheduler_ginka.step() scheduler_ginka.step(epoch + 1)
if (epoch + 1) % 5 == 0: if (epoch + 1) % 5 == 0:
loss_val = 0 loss_val = 0
ginka.eval() ginka.eval()
idx = 0 idx = 0
with torch.no_grad(): with torch.no_grad():
for batch in ginka_dataloader_val: for batch in tqdm(ginka_dataloader_val, leave=False, desc="Validating Ginka Model"):
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
output, output_softmax = ginka(feat_vec) output, output_softmax = ginka(feat_vec)
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat) losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
@ -202,9 +218,9 @@ def train():
if epoch + 1 == EPOCHS_GINKA: if epoch + 1 == EPOCHS_GINKA:
# 最后一次验证的时候顺带生成图片 # 最后一次验证的时候顺带生成图片
prob = output_softmax.cpu().numpy() prob = output_softmax.cpu().numpy()
np.concatenate((prob_list, prob), axis=1) prob_list = np.concatenate((prob_list, prob), axis=0)
map_matrix = torch.argmax(output, dim=1).cpu().numpy() map_matrix = torch.argmax(output, dim=1).cpu().numpy()
gen_list = np.concatenate((gen_list, map_matrix), axis=1) gen_list = np.concatenate((gen_list, map_matrix), axis=0)
for matrix in map_matrix: for matrix in map_matrix:
image = matrix_to_image_cv(matrix, tile_dict) image = matrix_to_image_cv(matrix, tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}.png", image) cv2.imwrite(f"result/ginka_img/{idx}.png", image)
@ -232,8 +248,8 @@ def train():
buf.extend(struct.pack('>b', H)) # Map height buf.extend(struct.pack('>b', H)) # Map height
buf.extend(struct.pack('>b', W)) # Map width buf.extend(struct.pack('>b', W)) # Map width
buf.extend(gen_bytes) # Map tensor buf.extend(gen_bytes) # Map tensor
server.sendall(buf) conn.sendall(buf)
data = parse_minamo_data(server, prob_list) data = parse_minamo_data(conn, prob_list)
minamo_dataset.set_data(data) minamo_dataset.set_data(data)
# -------------------- 训练判别器 # -------------------- 训练判别器
@ -242,7 +258,7 @@ def train():
minamo.train() minamo.train()
total_loss = 0 total_loss = 0
for batch in minamo_dataloader: for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"):
map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch) map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch)
if map1.shape[0] == 1: if map1.shape[0] == 1:
@ -265,16 +281,16 @@ def train():
total_loss += loss.item() total_loss += loss.item()
ave_loss = total_loss / len(minamo_dataloader) ave_loss = total_loss / len(minamo_dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}") tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_minamo.param_groups[0]['lr']):.6f}")
scheduler_minamo.step() scheduler_minamo.step(epoch + 1)
# 每十轮推理一次验证集 # 每十轮推理一次验证集
if (epoch + 1) % 5 == 0: if (epoch + 1) % 5 == 0:
minamo.eval() minamo.eval()
val_loss = 0 val_loss = 0
with torch.no_grad(): with torch.no_grad():
for val_batch in tqdm(minamo_dataloader_val, leave=False): for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"):
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch) map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)
vision_feat1, topo_feat1 = minamo(map1_val, graph1) vision_feat1, topo_feat1 = minamo(map1_val, graph1)

View File

@ -1,4 +1,5 @@
import torch.nn as nn import torch.nn as nn
from tqdm import tqdm
class MinamoLoss(nn.Module): class MinamoLoss(nn.Module):
def __init__(self, vision_weight=0.2, topo_weight=0.8): def __init__(self, vision_weight=0.2, topo_weight=0.8):
@ -9,7 +10,7 @@ class MinamoLoss(nn.Module):
def forward(self, vis_pred, topo_pred, vis_true, topo_true): def forward(self, vis_pred, topo_pred, vis_true, topo_true):
# print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape) # print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
# print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item()) # tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f}")
vis_loss = self.loss(vis_pred, vis_true) vis_loss = self.loss(vis_pred, vis_true)
topo_loss = self.loss(topo_pred, topo_true) topo_loss = self.loss(topo_pred, topo_true)
# print(vis_loss.item(), topo_loss.item()) # print(vis_loss.item(), topo_loss.item())