From ca068bbea30fc95792cef4f75dfa64a0f608f9ff Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sat, 22 Mar 2025 18:19:24 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BF=AE=E6=94=B9=E4=B8=8B=E9=87=87?= =?UTF-8?q?=E6=A0=B7=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/model/model.py | 4 ++-- ginka/model/sample.py | 2 +- ginka/train.py | 10 +++++----- ginka/validate.py | 6 +++--- minamo/dataset.py | 19 +++++++++++++++++++ minamo/train.py | 2 +- 6 files changed, 31 insertions(+), 12 deletions(-) diff --git a/ginka/model/model.py b/ginka/model/model.py index 5f8906e..7df6911 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -26,6 +26,6 @@ class GinkaModel(nn.Module): x = self.fc(feat) x = x.view(-1, self.base_ch, 32, 32) x = self.unet(x) - x = self.down_sample(x) - return F.softmax(x, dim=1) + x = F.interpolate(x, (13, 13), mode='bilinear') + return x, F.softmax(x, dim=1) \ No newline at end of file diff --git a/ginka/model/sample.py b/ginka/model/sample.py index 2af007d..1dd5193 100644 --- a/ginka/model/sample.py +++ b/ginka/model/sample.py @@ -7,7 +7,7 @@ class MapDownSample(nn.Module): self.down = nn.Sequential( nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1), nn.ReLU(), - nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=0) + nn.Conv2d(in_ch, out_ch, 4, stride=1, padding=0) ) def forward(self, x): diff --git a/ginka/train.py b/ginka/train.py index 4e069bb..be1fe97 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -48,7 +48,7 @@ def train(): ) # 设定优化器与调度器 - optimizer = optim.AdamW(model.parameters(), lr=1e-4) + optimizer = optim.AdamW(model.parameters(), lr=1e-3) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) criterion = GinkaLoss(minamo) @@ -75,10 +75,10 @@ def train(): feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) # 前向传播 optimizer.zero_grad() - output = model(feat_vec) + _, output_softmax = model(feat_vec) # 计算损失 - scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat) + scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat) # 反向传播 scaled_losses.backward() @@ -115,11 +115,11 @@ def train(): feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) # 前向传播 - output = model(feat_vec) + output, output_softmax = model(feat_vec) print(torch.argmax(output, dim=1)[0]) # 计算损失 - scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat) + scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat) loss_val += losses.item() avg_val_loss = loss_val / len(dataloader_val) diff --git a/ginka/validate.py b/ginka/validate.py index 76a1fbc..e100603 100644 --- a/ginka/validate.py +++ b/ginka/validate.py @@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32): def validate(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.") model = GinkaModel() - state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"] + state = torch.load("result/ginka.pth", map_location=device)["model_state"] model.load_state_dict(state) model.to(device) @@ -108,7 +108,7 @@ def validate(): target_topo_feat = batch["target_topo_feat"].to(device) feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) # 前向传播 - output = model(feat_vec) + output, output_softmax = model(feat_vec) map_matrix = torch.argmax(output, dim=1) for matrix in map_matrix[:].cpu(): @@ -118,7 +118,7 @@ def validate(): idx += 1 # 计算损失 - _, loss = criterion(output, target, target_vision_feat, target_topo_feat) + _, loss = criterion(output_softmax, target, target_vision_feat, target_topo_feat) val_loss += loss.item() avg_val_loss = val_loss / len(val_loader) diff --git a/minamo/dataset.py b/minamo/dataset.py index 0c55387..a543ff5 100644 --- a/minamo/dataset.py +++ b/minamo/dataset.py @@ -4,6 +4,22 @@ import torch.nn.functional as F from torch.utils.data import Dataset from shared.graph import convert_soft_map_to_graph +def random_smooth_onehot(onehot_map, min_main=0.8, max_main=1.0, epsilon=0.8): + """ + 生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动 + """ + C, H, W = onehot_map.shape + # 生成主类别的随机概率 (min_main, max_main) + main_prob = torch.rand(H, W) * (max_main - min_main) + min_main + + # 计算剩余概率并随机分配到其他类别 + noise = torch.rand(C, H, W) * epsilon # 随机噪声 + noise = noise / noise.sum(dim=1, keepdim=True) # 归一化到总和为 epsilon + + # 计算最终平滑 one-hot 结果 + smooth_onehot = onehot_map * main_prob + (1 - onehot_map) * noise + return smooth_onehot + def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: data = json.load(f) @@ -27,6 +43,9 @@ class MinamoDataset(Dataset): map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] + map1_probs = random_smooth_onehot(map1_probs) + map2_probs = random_smooth_onehot(map2_probs) + graph1 = convert_soft_map_to_graph(map1_probs) graph2 = convert_soft_map_to_graph(map2_probs) diff --git a/minamo/train.py b/minamo/train.py index e84bb14..eaac040 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -79,7 +79,7 @@ def train(): # for name, param in model.named_parameters(): # param.requires_grad = True - for batch in dataloader: + for batch in tqdm(dataloader, leave=False): # 数据迁移到设备 map1, map2, vision_simi, topo_simi, graph1, graph2 = batch map1 = map1.to(device) # 转为 [B, C, H, W]