feat: 提高参数量

This commit is contained in:
unanmed 2025-03-24 16:59:53 +08:00
parent 724b6612d3
commit f169167409
11 changed files with 65 additions and 38 deletions

View File

@ -1,4 +1,6 @@
for i in {$1...$2}
start=$1
end=$2
for ((i=start; i<=end; i=i+1))
do
sh gan.sh "$i"
echo "$i 次循环完成"

View File

@ -95,8 +95,8 @@ function generateTransformData(
types.push([rot, flip]);
}
}
// 随机抽取最多
const trans = chooseFrom(types, Math.floor(Math.random() * 2));
// 随机抽取最多
const trans = chooseFrom(types, Math.floor(Math.random() * 1));
return trans
.map(([rot, flip]) => {
const com1 = `${id1}.${rot}.${flip}:${id1}`;
@ -167,10 +167,10 @@ function generateTransformData(
}
function generateSimilarData(id: string, map: number[][]) {
// 生成最多个微调地图
// 生成最多个微调地图
const width = map[0].length;
const height = map.length;
const num = Math.floor(Math.random() * 3);
const num = Math.floor(Math.random() * 2);
const res: [id: string, data: MinamoTrainData][] = [];
for (let i = 0; i < num; i++) {
@ -241,7 +241,7 @@ function generatePair(
// 自身与自身对比的训练集,保证模型对相同地图输出 1
const self1 = `${id1}:${id1}`;
const self2 = `${id2}:${id2}`;
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 3));
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
const selfTrain1: MinamoTrainData = {
map1: map1,

6
gan.sh
View File

@ -8,10 +8,10 @@ python3 -m ginka.validate
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
cd data
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:40
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:30
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json"
pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json"
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json"
pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json"
cd ..

View File

@ -3,7 +3,8 @@ import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from minamo.model.model import MinamoModel
from shared.graph import convert_soft_map_to_graph
from shared.graph import differentiable_convert_to_data
from shared.utils import random_smooth_onehot
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:
@ -28,8 +29,9 @@ class GinkaDataset(Dataset):
def __getitem__(self, idx):
item = self.data[idx]
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float().to(self.device) # [32, H, W]
graph = convert_soft_map_to_graph(target).to(self.device)
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
target = random_smooth_onehot(target).to(self.device)
graph = differentiable_convert_to_data(target).to(self.device)
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
return {

View File

@ -10,8 +10,22 @@ class GinkaModel(nn.Module):
"""
super().__init__()
self.base_ch = base_ch
fc_dim = base_ch * 8 * 4 * 4
self.fc = nn.Sequential(
nn.Linear(feat_dim, 32 * 32 * base_ch)
nn.Linear(feat_dim, fc_dim),
nn.BatchNorm1d(fc_dim),
nn.ReLU()
)
self.deconv_layers = nn.Sequential(
nn.ConvTranspose2d(base_ch*8, base_ch*4, kernel_size=4, stride=2, padding=1), # Upsample 2x
nn.BatchNorm2d(base_ch*4),
nn.ReLU(),
nn.ConvTranspose2d(base_ch*4, base_ch*2, kernel_size=4, stride=2, padding=1), # Upsample 2x
nn.BatchNorm2d(base_ch*2),
nn.ReLU(),
nn.ConvTranspose2d(base_ch*2, base_ch, kernel_size=4, stride=2, padding=1), # Upsample 2x
nn.BatchNorm2d(base_ch),
nn.ReLU(),
)
self.unet = GinkaUNet(base_ch, num_classes)
self.down_sample = MapDownSample(num_classes, num_classes)
@ -25,7 +39,8 @@ class GinkaModel(nn.Module):
logits: 输出logits [BS, num_classes, H, W]
"""
x = self.fc(feat)
x = x.view(-1, self.base_ch, 32, 32)
x = x.view(-1, self.base_ch*8, 4, 4)
x = self.deconv_layers(x)
x = self.unet(x)
x = F.interpolate(x, (13, 13), mode='bilinear')
return x, F.softmax(x, dim=1)

View File

@ -48,7 +48,7 @@ def train():
)
# 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=5e-3)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss(minamo)
@ -72,7 +72,7 @@ def train():
target = batch["target"].to(device)
target_vision_feat = batch["target_vision_feat"].to(device)
target_topo_feat = batch["target_topo_feat"].to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
# 前向传播
optimizer.zero_grad()
_, output_softmax = model(feat_vec)
@ -84,6 +84,10 @@ def train():
scaled_losses.backward()
optimizer.step()
total_loss += losses.item()
# for name, param in model.named_parameters():
# if param.grad is not None:
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
avg_loss = total_loss / len(dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
@ -112,7 +116,7 @@ def train():
target = batch["target"].to(device)
target_vision_feat = batch["target_vision_feat"].to(device)
target_topo_feat = batch["target_topo_feat"].to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
# 前向传播
output, output_softmax = model(feat_vec)

View File

@ -106,7 +106,7 @@ def validate():
target = batch["target"].to(device)
target_vision_feat = batch["target_vision_feat"].to(device)
target_topo_feat = batch["target_topo_feat"].to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
# 前向传播
output, output_softmax = model(feat_vec)
map_matrix = torch.argmax(output, dim=1)

View File

@ -3,22 +3,7 @@ import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from shared.graph import differentiable_convert_to_data
def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25):
    """Randomly smooth a one-hot map into a soft per-cell class distribution.

    For every spatial cell the hot class receives a probability drawn
    uniformly from ``[min_main, max_main)``; the remaining mass
    ``1 - main_prob`` is split at random among the other classes, so every
    cell sums to 1 along the class axis.

    Args:
        onehot_map: Tensor unpacked as ``[C, H, W]``, one-hot along dim 0.
        min_main: Lower bound of the probability kept by the hot class.
        max_main: Upper bound of the probability kept by the hot class.
        epsilon: Retained for backward compatibility. The noise is
            renormalised, so its initial scale has no effect on the result.

    Returns:
        A floating-point tensor of shape ``[C, H, W]`` on the same device as
        ``onehot_map``.
    """
    C, H, W = onehot_map.shape
    # Integer one-hot inputs are promoted to the default float dtype.
    if onehot_map.is_floating_point():
        dtype = onehot_map.dtype
    else:
        dtype = torch.get_default_dtype()
    device = onehot_map.device
    onehot = onehot_map.to(dtype)

    # Probability kept by the hot class of each cell, in [min_main, max_main).
    main_prob = torch.rand(H, W, dtype=dtype, device=device) * (max_main - min_main) + min_main

    # Random weights for the non-hot classes only.
    noise = torch.rand(C, H, W, dtype=dtype, device=device) * (1 - onehot)
    # Normalise across the class axis (dim 0), not the height axis, and
    # guard the division for the degenerate case with no non-hot class.
    noise_total = noise.sum(dim=0, keepdim=True).clamp_min(torch.finfo(dtype).tiny)
    noise = noise / noise_total * (1 - main_prob)

    # Hot class gets main_prob; the other classes share the remainder.
    return onehot * main_prob + noise
from shared.utils import random_smooth_onehot
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:

View File

@ -110,7 +110,7 @@ def train():
total_loss += loss.item()
ave_loss = total_loss / len(dataloader)
print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
# total_norm = 0
# for p in model.parameters():
@ -128,7 +128,7 @@ def train():
scheduler.step()
# 每 5 轮推理一次验证集
if (epoch + 1) % 1 == 0:
if (epoch + 1) % 5 == 0:
model.eval()
val_loss = 0
with torch.no_grad():
@ -152,7 +152,7 @@ def train():
val_loss += loss_val.item()
avg_val_loss = val_loss / len(val_loader)
print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),

View File

@ -20,7 +20,7 @@ def validate():
print(f"Total parameters: {total_params}")
# 准备数据集
val_dataset = MinamoDataset("minamo-eval.json")
val_dataset = MinamoDataset("datasets/minamo-eval-1.json")
val_loader = DataLoader(
val_dataset,
batch_size=32,
@ -44,6 +44,8 @@ def validate():
vision_feat1, topo_feat1 = model(map1_val, graph1)
vision_feat2, topo_feat2 = model(map2_val, graph2)
print(vision_feat1.isnan().any().item(), topo_feat1.isnan().any().item(), vision_feat2.isnan().any().item(), topo_feat2.isnan().any().item())
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
loss_val = criterion(

17
shared/utils.py Normal file
View File

@ -0,0 +1,17 @@
import torch
def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25):
    """Randomly smooth a one-hot map into a soft per-cell class distribution.

    For every spatial cell the hot class receives a probability drawn
    uniformly from ``[min_main, max_main)``; the remaining mass
    ``1 - main_prob`` is split at random among the other classes, so every
    cell sums to 1 along the class axis.

    Args:
        onehot_map: Tensor unpacked as ``[C, H, W]``, one-hot along dim 0.
        min_main: Lower bound of the probability kept by the hot class.
        max_main: Upper bound of the probability kept by the hot class.
        epsilon: Retained for backward compatibility. The noise is
            renormalised, so its initial scale has no effect on the result.

    Returns:
        A floating-point tensor of shape ``[C, H, W]`` on the same device as
        ``onehot_map``.
    """
    C, H, W = onehot_map.shape
    # Integer one-hot inputs are promoted to the default float dtype.
    if onehot_map.is_floating_point():
        dtype = onehot_map.dtype
    else:
        dtype = torch.get_default_dtype()
    device = onehot_map.device
    onehot = onehot_map.to(dtype)

    # Probability kept by the hot class of each cell, in [min_main, max_main).
    main_prob = torch.rand(H, W, dtype=dtype, device=device) * (max_main - min_main) + min_main

    # Random weights for the non-hot classes only.
    noise = torch.rand(C, H, W, dtype=dtype, device=device) * (1 - onehot)
    # Normalise across the class axis (dim 0), not the height axis, and
    # guard the division for the degenerate case with no non-hot class.
    noise_total = noise.sum(dim=0, keepdim=True).clamp_min(torch.finfo(dtype).tiny)
    noise = noise / noise_total * (1 - main_prob)

    # Hot class gets main_prob; the other classes share the remainder.
    return onehot * main_prob + noise