fix: floats were unexpectedly converted to LongTensor

unanmed 2025-04-03 12:12:13 +08:00
parent 325eb599c3
commit 6a1aeaa77e
7 changed files with 42 additions and 63 deletions

View File

@@ -35,7 +35,7 @@ function generateGANData(
map: number[][]
) {
const id2 = `$${id++}`;
-const toTrain = chooseFrom(keys, 30);
+const toTrain = chooseFrom(keys, 4);
const data = toTrain.map<MinamoTrainData[]>(v => {
const floor = refer.get(v);
if (!floor) return [];

View File

@@ -87,7 +87,8 @@ function weisfeilerLehmanIteration(
});
weight *= decay;
});
-// Also add each node's original label, using the farthest weight; this can be regarded as the resource duplication rate
+// Also add each node's original label, using the farthest weight decayed two more times; this can be regarded as the resource duplication rate
+weight *= decay ** 2;
nodes.forEach(node => {
if (!numMap.has(node.originalLabel)) {
numMap.set(node.originalLabel, weight);
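The changed step, restated as a hedged Python sketch (the actual implementation is the TypeScript above; the helper name and flat label list are illustrative):

# After the WL iterations, count each node's original label once more,
# at the farthest weight decayed two extra steps; the comment reads this
# as a resource-duplication rate.
def add_original_labels(num_map: dict, original_labels: list, weight: float, decay: float) -> None:
    weight *= decay ** 2
    for label in original_labels:
        if label not in num_map:
            num_map[label] = weight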

View File

@@ -37,7 +37,10 @@ class GinkaDataset(Dataset):
item = self.data[idx]
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
-target_smooth = random_smooth_onehot(target)
+min_main = random.uniform(0.75, 0.9)
+max_main = random.uniform(0.9, 1)
+epsilon = random.uniform(0, 0.25)
+target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon)
graph = differentiable_convert_to_data(target_smooth).to(self.device)
target = target.to(self.device)
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
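random_smooth_onehot itself is not part of this diff; a plausible sketch consistent with the new call signature (the body, including how epsilon jitters the smoothing, is an assumption):

import torch

def random_smooth_onehot(onehot: torch.Tensor, min_main=0.75, max_main=0.9, epsilon=0.25) -> torch.Tensor:
    # onehot: [C, H, W]; keep the true class at a random probability in
    # [min_main, max_main], spread the remainder over the other classes
    # with up to +/- epsilon relative jitter, then renormalize per pixel.
    C = onehot.shape[0]
    main = torch.empty(onehot.shape[1:]).uniform_(min_main, max_main)
    rest = (1 - main) / (C - 1)
    noise = (torch.rand_like(onehot) * 2 - 1) * epsilon * rest
    smooth = onehot * main + (1 - onehot) * (rest + noise)
    return smooth / smooth.sum(0, keepdim=True)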
@@ -67,12 +70,12 @@ class MinamoGANDataset(Dataset):
item = self.data[idx]
map1, map2, vis_sim, topo_sim, review = item
-map1 = torch.LongTensor(map1)
-map2 = torch.LongTensor(map2)
+# Check for the review label; if it is absent, the map is a probability distribution and needs no conversion
+if review:
-map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
-map2 = F.one_hot(map2, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
+map1 = F.one_hot(torch.LongTensor(map1), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
+else:
+map1 = torch.FloatTensor(map1)
+map2 = F.one_hot(torch.LongTensor(map2), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
min_main = random.uniform(0.75, 0.9)
max_main = random.uniform(0.9, 1)
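This branch is the fix named in the commit title: torch.LongTensor silently truncates float probabilities to integers, so a soft map routed through it is destroyed. A minimal repro with illustrative values:

import torch

soft_row = [0.1, 0.7, 0.2]            # one row of a probability-distribution map
print(torch.LongTensor(soft_row))     # tensor([0, 0, 0]) -- floats truncated to long
print(torch.FloatTensor(soft_row))    # tensor([0.1000, 0.7000, 0.2000]) -- preserved

Hence only hard label maps (review present) go through LongTensor + one_hot; soft maps stay float.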

View File

@@ -249,9 +249,10 @@ class GinkaLoss(nn.Module):
graph = batch_convert_soft_map_to_graph(pred)
pred_vision_feat, pred_topo_feat = self.minamo(pred, graph)
-vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
-topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
-minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim
+vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=1)
+topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=1)
+minamo_sim = 0 * vision_sim + 1 * topo_sim
# tqdm.write(f"{vision_sim.mean().item():.12f}, {topo_sim.mean().item():.12f}")
minamo_loss = (1.0 - minamo_sim).mean()
tqdm.write(
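For plain [B, D] feature tensors, dim=1 and dim=-1 are the same axis; the explicit dim=1 only behaves differently when a stray singleton dimension survives (e.g. [B, 1, D]), which matches the squeeze(1) fixes elsewhere in this commit. The reweighting to pure topology similarity is a separate tuning change. A quick shape check (sizes assumed):

import torch
import torch.nn.functional as F

a = torch.randn(8, 1, 64)   # feature with a leftover singleton dimension
b = torch.randn(8, 1, 64)
print(F.cosine_similarity(a, b, dim=-1).shape)                       # torch.Size([8, 1])
print(F.cosine_similarity(a.squeeze(1), b.squeeze(1), dim=1).shape)  # torch.Size([8])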

View File

@@ -19,7 +19,7 @@ from shared.image import matrix_to_image_cv
BATCH_SIZE = 32
EPOCHS_GINKA = 30
-EPOCHS_MINAMO = 15
+EPOCHS_MINAMO = 10
SOCKET_PATH = "./tmp/ginka_uds"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -40,9 +40,9 @@ def parse_arguments():
def parse_ginka_batch(batch):
target = batch["target"].to(device)
-target_vision_feat = batch["target_vision_feat"].to(device)
-target_topo_feat = batch["target_topo_feat"].to(device)
-feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
+target_vision_feat = batch["target_vision_feat"].to(device).squeeze(1)
+target_topo_feat = batch["target_topo_feat"].to(device).squeeze(1)
+feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=1).to(device)
return target, target_vision_feat, target_topo_feat, feat_vec
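With the features squeezed to [B, D] first, the concatenation along dim=1 yields a clean [B, 2D] vector; the old code concatenated [B, 1, D] tensors along the last axis and squeezed afterwards. Shape check (D = 128 is assumed):

import torch

vision = torch.randn(32, 1, 128)   # [B, 1, D] as produced by the dataset
topo = torch.randn(32, 1, 128)
feat_vec = torch.cat([vision.squeeze(1), topo.squeeze(1)], dim=1)
print(feat_vec.shape)              # torch.Size([32, 256])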
@@ -133,8 +133,8 @@ def train():
minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json")
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
-minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True)
-minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
+minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE // 2, shuffle=True)
+minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE // 2, shuffle=True)
# Set up the optimizers and schedulers
optimizer_ginka = optim.AdamW(ginka.parameters(), lr=1e-3)
@@ -142,7 +142,7 @@ def train():
criterion_ginka = GinkaLoss(minamo)
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
-scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
+scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
criterion_minamo = MinamoLoss()
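With T_0=10 and T_mult=2, the warm restarts now land after epochs 10 and 30 rather than 5, 15, and 35. A tiny loop makes the schedule visible (dummy parameter, illustrative only):

import torch
from torch import optim

param = torch.nn.Parameter(torch.zeros(1))
opt = optim.AdamW([param], lr=1e-3)
sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=2, eta_min=1e-6)
for epoch in range(31):
    opt.step()    # dummy step so the scheduler ordering warning stays quiet
    sched.step()
    print(epoch, opt.param_groups[0]["lr"])  # lr resets toward 1e-3 after epochs 10 and 30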
# Used for image generation
@@ -158,20 +158,18 @@ def train():
server.bind(SOCKET_PATH)
server.listen(1)
print("Waiting for client connection...")
conn, _ = server.accept()
print("Client connected.")
if args.resume:
data = torch.load(args.from_state, map_location=device)
ginka.load_state_dict(data["model_state"], strict=False)
if args.load_optim:
optimizer_ginka.load_state_dict(data["optimizer_state"])
print("Train from loaded state.")
else:
# When training from scratch, start with the minamo loss weight set to 0
criterion_ginka.weight[0] = 0.0
print("Waiting for client connection...")
conn, _ = server.accept()
print("Client connected.")
for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
# -------------------- Train the generator
@@ -217,10 +215,7 @@ def train():
loss_val += losses.item()
if epoch + 1 == EPOCHS_GINKA:
# On the final validation pass, also generate images
-prob = output_softmax.cpu().numpy()
-prob_list = np.concatenate((prob_list, prob), axis=0)
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
-gen_list = np.concatenate((gen_list, map_matrix), axis=0)
for matrix in map_matrix:
image = matrix_to_image_cv(matrix, tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
@@ -231,6 +226,16 @@ def train():
torch.save({
"model_state": ginka.state_dict()
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
+# Generate minamo training data from the training set; this is more accurate
+with torch.no_grad():
+for batch in ginka_dataloader:
+target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
+output, output_softmax = ginka(feat_vec)
+prob = output_softmax.cpu().numpy()
+prob_list = np.concatenate((prob_list, prob), axis=0)
+map_matrix = torch.argmax(output, dim=1).cpu().numpy()
+gen_list = np.concatenate((gen_list, map_matrix), axis=0)
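np.concatenate along axis 0 assumes prob_list and gen_list were initialized earlier as empty arrays with matching trailing shapes, presumably something like (H and W assumed):

import numpy as np

H, W = 13, 13                                     # map size, assumed
prob_list = np.empty((0, 32, H, W), dtype=np.float32)
gen_list = np.empty((0, H, W), dtype=np.int64)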
tqdm.write(f"Cycle {cycle} Ginka train ended.")
torch.save({
@@ -269,8 +274,8 @@ def train():
vision_feat1, topo_feat1 = minamo(map1, graph1)
vision_feat2, topo_feat2 = minamo(map2, graph2)
-vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
-topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
+vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1)
+topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1)
# Compute the loss
loss = criterion_minamo(vision_pred, topo_pred, vision_simi, topo_simi)
@@ -296,8 +301,8 @@ def train():
vision_feat1, topo_feat1 = minamo(map1_val, graph1)
vision_feat2, topo_feat2 = minamo(map2_val, graph2)
-vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
-topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
+vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1)
+topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1)
# Compute the loss
loss_val = criterion_minamo(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
@@ -312,7 +317,7 @@ def train():
tqdm.write(f"Cycle {cycle} Minamo train ended.")
torch.save({
"model_state": minamo.state_dict()
}, f"result/ginka.pth")
}, f"result/minamo.pth")
print("Train ended.")

View File

@@ -10,8 +10,8 @@ class MinamoLoss(nn.Module):
def forward(self, vis_pred, topo_pred, vis_true, topo_true):
# print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
-# tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f}")
vis_loss = self.loss(vis_pred, vis_true)
topo_loss = self.loss(topo_pred, topo_true)
+# tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f} | {vis_loss.item():.12f}, {topo_loss.item():.12f}")
# print(vis_loss.item(), topo_loss.item())
return self.vision_weight * vis_loss + self.topo_weight * topo_loss
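The rest of the class is outside this hunk; a minimal sketch consistent with this forward (the MSE choice and the default weights are assumptions):

import torch.nn as nn

class MinamoLoss(nn.Module):
    def __init__(self, vision_weight=0.5, topo_weight=0.5):
        super().__init__()
        self.vision_weight = vision_weight
        self.topo_weight = topo_weight
        self.loss = nn.MSELoss()   # assumed; the diff only shows self.loss being called

    def forward(self, vis_pred, topo_pred, vis_true, topo_true):
        vis_loss = self.loss(vis_pred, vis_true)
        topo_loss = self.loss(topo_pred, topo_true)
        return self.vision_weight * vis_loss + self.topo_weight * topo_loss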

View File

@@ -52,37 +52,6 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
num_nodes=N
)
-def convert_soft_map_to_graph(map_probs: torch.Tensor):
-"""
-Build a soft graph structure directly from the Softmax probabilities
-"""
-C, H, W = map_probs.shape # [32, H, W]
-N = H * W
-device = map_probs.device
-# Compute soft node features
-node_features = map_probs.view(C, N).T # [N, C]
-# Compute soft adjacency edges (based on soft weights)
-edge_list = []
-for r in range(H):
-for c in range(W):
-node = r * W + c
-if c + 1 < W:
-right = node + 1
-edge_list.append([node, right])
-if r + 1 < H:
-down = node + W
-edge_list.append([node, down])
-edge_index = torch.tensor(edge_list).t().to(device)
-# Compute soft edge weights (based on Softmax probabilities)
-soft_edge_weight = (map_probs[:, edge_index[0] // W, edge_index[0] % W] +
-map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2
-return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight)
def batch_convert_soft_map_to_graph(batch_map_probs):
"""
Handle the batch dimension: convert [B, C, H, W] into a batched graph structure (Batch)
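The function is cut off here; one way to realize the docstring with torch_geometric, reusing differentiable_convert_to_data from above (a sketch, not necessarily the original body):

import torch
from torch_geometric.data import Batch

def batch_convert_soft_map_to_graph(batch_map_probs: torch.Tensor) -> Batch:
    # [B, C, H, W] -> one graph per sample -> a single batched Batch object
    data_list = [differentiable_convert_to_data(probs) for probs in batch_map_probs]
    return Batch.from_data_list(data_list)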