diff --git a/cycle2.sh b/cycle2.sh index 98187c8..84011e9 100644 --- a/cycle2.sh +++ b/cycle2.sh @@ -1,4 +1,6 @@ -for i in {$1...$2} +start=$1 +end=$2 +for ((i=start; i<=end; i=i+1)) do sh gan.sh "$i" echo "第 $i 次循环完成" diff --git a/data/src/minamo.ts b/data/src/minamo.ts index 9d7980d..dd6578e 100644 --- a/data/src/minamo.ts +++ b/data/src/minamo.ts @@ -95,8 +95,8 @@ function generateTransformData( types.push([rot, flip]); } } - // 随机抽取最多两个 - const trans = chooseFrom(types, Math.floor(Math.random() * 2)); + // 随机抽取最多一个 + const trans = chooseFrom(types, Math.floor(Math.random() * 1)); return trans .map(([rot, flip]) => { const com1 = `${id1}.${rot}.${flip}:${id1}`; @@ -167,10 +167,10 @@ function generateTransformData( } function generateSimilarData(id: string, map: number[][]) { - // 生成最多五个微调地图 + // 生成最多两个微调地图 const width = map[0].length; const height = map.length; - const num = Math.floor(Math.random() * 3); + const num = Math.floor(Math.random() * 2); const res: [id: string, data: MinamoTrainData][] = []; for (let i = 0; i < num; i++) { @@ -241,7 +241,7 @@ function generatePair( // 自身与自身对比的训练集,保证模型对相同地图输出 1 const self1 = `${id1}:${id1}`; const self2 = `${id2}:${id2}`; - const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 3)); + const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1)); if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) { const selfTrain1: MinamoTrainData = { map1: map1, diff --git a/gan.sh b/gan.sh index 4616733..c8caf20 100644 --- a/gan.sh +++ b/gan.sh @@ -8,10 +8,10 @@ python3 -m ginka.validate mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json" mv "minamo-eval.json" "datasets/minamo-eval-$1.json" cd data -pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:40 +pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:30 pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10 -pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json" -pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json" pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json" pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json" +pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json" +pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json" cd .. diff --git a/ginka/dataset.py b/ginka/dataset.py index 10dbfb4..0fa3502 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -3,7 +3,8 @@ import torch import torch.nn.functional as F from torch.utils.data import Dataset from minamo.model.model import MinamoModel -from shared.graph import convert_soft_map_to_graph +from shared.graph import differentiable_convert_to_data +from shared.utils import random_smooth_onehot def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: @@ -28,8 +29,9 @@ class GinkaDataset(Dataset): def __getitem__(self, idx): item = self.data[idx] - target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float().to(self.device) # [32, H, W] - graph = convert_soft_map_to_graph(target).to(self.device) + target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] + target = random_smooth_onehot(target).to(self.device) + graph = differentiable_convert_to_data(target).to(self.device) vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) return { diff --git a/ginka/model/model.py b/ginka/model/model.py index eddc047..6415dbd 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -10,8 +10,22 @@ class GinkaModel(nn.Module): """ super().__init__() self.base_ch = base_ch + fc_dim = base_ch * 8 * 4 * 4 self.fc = nn.Sequential( - nn.Linear(feat_dim, 32 * 32 * base_ch) + nn.Linear(feat_dim, fc_dim), + nn.BatchNorm1d(fc_dim), + nn.ReLU() + ) + self.deconv_layers = nn.Sequential( + nn.ConvTranspose2d(base_ch*8, base_ch*4, kernel_size=4, stride=2, padding=1), # Upsample 2x + nn.BatchNorm2d(base_ch*4), + nn.ReLU(), + nn.ConvTranspose2d(base_ch*4, base_ch*2, kernel_size=4, stride=2, padding=1), # Upsample 2x + nn.BatchNorm2d(base_ch*2), + nn.ReLU(), + nn.ConvTranspose2d(base_ch*2, base_ch, kernel_size=4, stride=2, padding=1), # Upsample 2x + nn.BatchNorm2d(base_ch), + nn.ReLU(), ) self.unet = GinkaUNet(base_ch, num_classes) self.down_sample = MapDownSample(num_classes, num_classes) @@ -25,7 +39,8 @@ class GinkaModel(nn.Module): logits: 输出logits [BS, num_classes, H, W] """ x = self.fc(feat) - x = x.view(-1, self.base_ch, 32, 32) + x = x.view(-1, self.base_ch*8, 4, 4) + x = self.deconv_layers(x) x = self.unet(x) x = F.interpolate(x, (13, 13), mode='bilinear') return x, F.softmax(x, dim=1) diff --git a/ginka/train.py b/ginka/train.py index 61f9d1c..96d5d35 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -48,7 +48,7 @@ def train(): ) # 设定优化器与调度器 - optimizer = optim.AdamW(model.parameters(), lr=5e-3) + optimizer = optim.AdamW(model.parameters(), lr=1e-3) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) criterion = GinkaLoss(minamo) @@ -72,7 +72,7 @@ def train(): target = batch["target"].to(device) target_vision_feat = batch["target_vision_feat"].to(device) target_topo_feat = batch["target_topo_feat"].to(device) - feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) + feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) # 前向传播 optimizer.zero_grad() _, output_softmax = model(feat_vec) @@ -84,6 +84,10 @@ def train(): scaled_losses.backward() optimizer.step() total_loss += losses.item() + # for name, param in model.named_parameters(): + # if param.grad is not None: + # print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}") + avg_loss = total_loss / len(dataloader) tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") @@ -112,7 +116,7 @@ def train(): target = batch["target"].to(device) target_vision_feat = batch["target_vision_feat"].to(device) target_topo_feat = batch["target_topo_feat"].to(device) - feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) + feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) # 前向传播 output, output_softmax = model(feat_vec) diff --git a/ginka/validate.py b/ginka/validate.py index e100603..89ef941 100644 --- a/ginka/validate.py +++ b/ginka/validate.py @@ -106,7 +106,7 @@ def validate(): target = batch["target"].to(device) target_vision_feat = batch["target_vision_feat"].to(device) target_topo_feat = batch["target_topo_feat"].to(device) - feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) + feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) # 前向传播 output, output_softmax = model(feat_vec) map_matrix = torch.argmax(output, dim=1) diff --git a/minamo/dataset.py b/minamo/dataset.py index 4dda47c..0aa8060 100644 --- a/minamo/dataset.py +++ b/minamo/dataset.py @@ -3,22 +3,7 @@ import torch import torch.nn.functional as F from torch.utils.data import Dataset from shared.graph import differentiable_convert_to_data - -def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25): - """ - 生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动 - """ - C, H, W = onehot_map.shape - # 生成主类别的随机概率 (min_main, max_main) - main_prob = torch.rand(H, W) * (max_main - min_main) + min_main - - # 计算剩余概率并随机分配到其他类别 - noise = torch.rand(C, H, W) * epsilon # 随机噪声 - noise = noise / noise.sum(dim=1, keepdim=True) # 归一化到总和为 epsilon - - # 计算最终平滑 one-hot 结果 - smooth_onehot = onehot_map * main_prob + (1 - onehot_map) * noise - return smooth_onehot +from shared.utils import random_smooth_onehot def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: diff --git a/minamo/train.py b/minamo/train.py index e4416ad..24f2dcd 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -110,7 +110,7 @@ def train(): total_loss += loss.item() ave_loss = total_loss / len(dataloader) - print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") # total_norm = 0 # for p in model.parameters(): @@ -128,7 +128,7 @@ def train(): scheduler.step() # 每十轮推理一次验证集 - if (epoch + 1) % 1 == 0: + if (epoch + 1) % 5 == 0: model.eval() val_loss = 0 with torch.no_grad(): @@ -152,7 +152,7 @@ def train(): val_loss += loss_val.item() avg_val_loss = val_loss / len(val_loader) - print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") torch.save({ "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), diff --git a/minamo/validate.py b/minamo/validate.py index d8e7e3a..8eae08f 100644 --- a/minamo/validate.py +++ b/minamo/validate.py @@ -20,7 +20,7 @@ def validate(): print(f"Total parameters: {total_params}") # 准备数据集 - val_dataset = MinamoDataset("minamo-eval.json") + val_dataset = MinamoDataset("datasets/minamo-eval-1.json") val_loader = DataLoader( val_dataset, batch_size=32, @@ -44,6 +44,8 @@ def validate(): vision_feat1, topo_feat1 = model(map1_val, graph1) vision_feat2, topo_feat2 = model(map2_val, graph2) + print(vision_feat1.isnan().any().item(), topo_feat1.isnan().any().item(), vision_feat2.isnan().any().item(), topo_feat2.isnan().any().item()) + vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) loss_val = criterion( diff --git a/shared/utils.py b/shared/utils.py new file mode 100644 index 0000000..4c00e88 --- /dev/null +++ b/shared/utils.py @@ -0,0 +1,17 @@ +import torch + +def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25): + """ + 生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动 + """ + C, H, W = onehot_map.shape + # 生成主类别的随机概率 (min_main, max_main) + main_prob = torch.rand(H, W) * (max_main - min_main) + min_main + + # 计算剩余概率并随机分配到其他类别 + noise = torch.rand(C, H, W) * epsilon # 随机噪声 + noise = noise / noise.sum(dim=1, keepdim=True) # 归一化到总和为 epsilon + + # 计算最终平滑 one-hot 结果 + smooth_onehot = onehot_map * main_prob + (1 - onehot_map) * noise + return smooth_onehot \ No newline at end of file