mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-18 15:41:11 +08:00
feat: 提高参数量
This commit is contained in:
parent
724b6612d3
commit
f169167409
@ -1,4 +1,6 @@
|
||||
for i in {$1...$2}
|
||||
start=$1
|
||||
end=$2
|
||||
for ((i=start; i<=end; i=i+1))
|
||||
do
|
||||
sh gan.sh "$i"
|
||||
echo "第 $i 次循环完成"
|
||||
|
||||
@ -95,8 +95,8 @@ function generateTransformData(
|
||||
types.push([rot, flip]);
|
||||
}
|
||||
}
|
||||
// 随机抽取最多两个
|
||||
const trans = chooseFrom(types, Math.floor(Math.random() * 2));
|
||||
// 随机抽取最多一个
|
||||
const trans = chooseFrom(types, Math.floor(Math.random() * 1));
|
||||
return trans
|
||||
.map(([rot, flip]) => {
|
||||
const com1 = `${id1}.${rot}.${flip}:${id1}`;
|
||||
@ -167,10 +167,10 @@ function generateTransformData(
|
||||
}
|
||||
|
||||
function generateSimilarData(id: string, map: number[][]) {
|
||||
// 生成最多五个微调地图
|
||||
// 生成最多两个微调地图
|
||||
const width = map[0].length;
|
||||
const height = map.length;
|
||||
const num = Math.floor(Math.random() * 3);
|
||||
const num = Math.floor(Math.random() * 2);
|
||||
const res: [id: string, data: MinamoTrainData][] = [];
|
||||
|
||||
for (let i = 0; i < num; i++) {
|
||||
@ -241,7 +241,7 @@ function generatePair(
|
||||
// 自身与自身对比的训练集,保证模型对相同地图输出 1
|
||||
const self1 = `${id1}:${id1}`;
|
||||
const self2 = `${id2}:${id2}`;
|
||||
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 3));
|
||||
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
|
||||
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
|
||||
const selfTrain1: MinamoTrainData = {
|
||||
map1: map1,
|
||||
|
||||
6
gan.sh
6
gan.sh
@ -8,10 +8,10 @@ python3 -m ginka.validate
|
||||
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
||||
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
|
||||
cd data
|
||||
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:40
|
||||
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:30
|
||||
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
|
||||
pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json"
|
||||
pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json"
|
||||
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
|
||||
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
|
||||
pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json"
|
||||
pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json"
|
||||
cd ..
|
||||
|
||||
@ -3,7 +3,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
from minamo.model.model import MinamoModel
|
||||
from shared.graph import convert_soft_map_to_graph
|
||||
from shared.graph import differentiable_convert_to_data
|
||||
from shared.utils import random_smooth_onehot
|
||||
|
||||
def load_data(path: str):
|
||||
with open(path, 'r', encoding="utf-8") as f:
|
||||
@ -28,8 +29,9 @@ class GinkaDataset(Dataset):
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
|
||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float().to(self.device) # [32, H, W]
|
||||
graph = convert_soft_map_to_graph(target).to(self.device)
|
||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
target = random_smooth_onehot(target).to(self.device)
|
||||
graph = differentiable_convert_to_data(target).to(self.device)
|
||||
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
||||
|
||||
return {
|
||||
|
||||
@ -10,8 +10,22 @@ class GinkaModel(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.base_ch = base_ch
|
||||
fc_dim = base_ch * 8 * 4 * 4
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(feat_dim, 32 * 32 * base_ch)
|
||||
nn.Linear(feat_dim, fc_dim),
|
||||
nn.BatchNorm1d(fc_dim),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.deconv_layers = nn.Sequential(
|
||||
nn.ConvTranspose2d(base_ch*8, base_ch*4, kernel_size=4, stride=2, padding=1), # Upsample 2x
|
||||
nn.BatchNorm2d(base_ch*4),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(base_ch*4, base_ch*2, kernel_size=4, stride=2, padding=1), # Upsample 2x
|
||||
nn.BatchNorm2d(base_ch*2),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(base_ch*2, base_ch, kernel_size=4, stride=2, padding=1), # Upsample 2x
|
||||
nn.BatchNorm2d(base_ch),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.unet = GinkaUNet(base_ch, num_classes)
|
||||
self.down_sample = MapDownSample(num_classes, num_classes)
|
||||
@ -25,7 +39,8 @@ class GinkaModel(nn.Module):
|
||||
logits: 输出logits [BS, num_classes, H, W]
|
||||
"""
|
||||
x = self.fc(feat)
|
||||
x = x.view(-1, self.base_ch, 32, 32)
|
||||
x = x.view(-1, self.base_ch*8, 4, 4)
|
||||
x = self.deconv_layers(x)
|
||||
x = self.unet(x)
|
||||
x = F.interpolate(x, (13, 13), mode='bilinear')
|
||||
return x, F.softmax(x, dim=1)
|
||||
|
||||
@ -48,7 +48,7 @@ def train():
|
||||
)
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=5e-3)
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion = GinkaLoss(minamo)
|
||||
|
||||
@ -72,7 +72,7 @@ def train():
|
||||
target = batch["target"].to(device)
|
||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
_, output_softmax = model(feat_vec)
|
||||
@ -84,6 +84,10 @@ def train():
|
||||
scaled_losses.backward()
|
||||
optimizer.step()
|
||||
total_loss += losses.item()
|
||||
# for name, param in model.named_parameters():
|
||||
# if param.grad is not None:
|
||||
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
|
||||
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
@ -112,7 +116,7 @@ def train():
|
||||
target = batch["target"].to(device)
|
||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||
|
||||
# 前向传播
|
||||
output, output_softmax = model(feat_vec)
|
||||
|
||||
@ -106,7 +106,7 @@ def validate():
|
||||
target = batch["target"].to(device)
|
||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||
# 前向传播
|
||||
output, output_softmax = model(feat_vec)
|
||||
map_matrix = torch.argmax(output, dim=1)
|
||||
|
||||
@ -3,22 +3,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
from shared.graph import differentiable_convert_to_data
|
||||
|
||||
def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25):
    """Randomly smooth a one-hot map so the main class probability fluctuates.

    For every spatial cell the main (hot) class receives a random probability
    drawn uniformly from ``[min_main, max_main)``; the remaining probability
    mass is split randomly among the other classes, so the class axis of each
    cell sums to 1.

    Args:
        onehot_map: One-hot tensor laid out as ``[C, H, W]`` (class axis
            first, as implied by the ``C, H, W`` unpacking).
        min_main: Lower bound of the main-class probability.
        max_main: Upper bound of the main-class probability.
        epsilon: Upper bound on the total probability mass moved to the
            non-main classes; the main-class probability is never allowed to
            drop below ``1 - epsilon``. With the defaults this bound coincides
            with ``min_main`` and has no additional effect.

    Returns:
        A tensor of the same shape as ``onehot_map`` holding the smoothed
        distribution, created on the same device as the input.

    Raises:
        ValueError: If ``onehot_map`` is not 3-dimensional (from unpacking
            its shape).
    """
    C, H, W = onehot_map.shape
    device = onehot_map.device
    # Integer one-hot inputs are promoted to the default float dtype.
    if onehot_map.is_floating_point():
        dtype = onehot_map.dtype
    else:
        dtype = torch.get_default_dtype()
    onehot = onehot_map.to(dtype)

    # Random main-class probability per cell, in [min_main, max_main).
    main_prob = torch.rand(H, W, device=device, dtype=dtype) * (max_main - min_main) + min_main
    # Keep the off-class mass within [0, epsilon] so the result stays a
    # valid distribution.
    main_prob = main_prob.clamp(min=1.0 - epsilon, max=1.0)

    # Random noise restricted to the non-main classes.
    off_noise = (1 - onehot) * torch.rand(C, H, W, device=device, dtype=dtype)
    # Normalize along the CLASS axis (dim 0 of [C, H, W]); the clamp avoids a
    # division by zero when there is no non-main class (C == 1).
    off_total = off_noise.sum(dim=0, keepdim=True).clamp_min(torch.finfo(dtype).tiny)
    off_noise = off_noise / off_total * (1.0 - main_prob)

    # The main class takes whatever mass the other classes did not receive,
    # which guarantees each cell sums to 1 even when C == 1.
    main_mass = 1.0 - off_noise.sum(dim=0, keepdim=True)
    return onehot * main_mass + off_noise
|
||||
from shared.utils import random_smooth_onehot
|
||||
|
||||
def load_data(path: str):
|
||||
with open(path, 'r', encoding="utf-8") as f:
|
||||
|
||||
@ -110,7 +110,7 @@ def train():
|
||||
total_loss += loss.item()
|
||||
|
||||
ave_loss = total_loss / len(dataloader)
|
||||
print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
|
||||
# total_norm = 0
|
||||
# for p in model.parameters():
|
||||
@ -128,7 +128,7 @@ def train():
|
||||
scheduler.step()
|
||||
|
||||
# 每十轮推理一次验证集
|
||||
if (epoch + 1) % 1 == 0:
|
||||
if (epoch + 1) % 5 == 0:
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
@ -152,7 +152,7 @@ def train():
|
||||
val_loss += loss_val.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||
torch.save({
|
||||
"model_state": model.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
|
||||
@ -20,7 +20,7 @@ def validate():
|
||||
print(f"Total parameters: {total_params}")
|
||||
|
||||
# 准备数据集
|
||||
val_dataset = MinamoDataset("minamo-eval.json")
|
||||
val_dataset = MinamoDataset("datasets/minamo-eval-1.json")
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=32,
|
||||
@ -44,6 +44,8 @@ def validate():
|
||||
vision_feat1, topo_feat1 = model(map1_val, graph1)
|
||||
vision_feat2, topo_feat2 = model(map2_val, graph2)
|
||||
|
||||
print(vision_feat1.isnan().any().item(), topo_feat1.isnan().any().item(), vision_feat2.isnan().any().item(), topo_feat2.isnan().any().item())
|
||||
|
||||
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
loss_val = criterion(
|
||||
|
||||
17
shared/utils.py
Normal file
17
shared/utils.py
Normal file
@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25):
    """Randomly smooth a one-hot map so the main class probability fluctuates.

    For every spatial cell the main (hot) class receives a random probability
    drawn uniformly from ``[min_main, max_main)``; the remaining probability
    mass is split randomly among the other classes, so the class axis of each
    cell sums to 1.

    Args:
        onehot_map: One-hot tensor laid out as ``[C, H, W]`` (class axis
            first, as implied by the ``C, H, W`` unpacking).
        min_main: Lower bound of the main-class probability.
        max_main: Upper bound of the main-class probability.
        epsilon: Upper bound on the total probability mass moved to the
            non-main classes; the main-class probability is never allowed to
            drop below ``1 - epsilon``. With the defaults this bound coincides
            with ``min_main`` and has no additional effect.

    Returns:
        A tensor of the same shape as ``onehot_map`` holding the smoothed
        distribution, created on the same device as the input.

    Raises:
        ValueError: If ``onehot_map`` is not 3-dimensional (from unpacking
            its shape).
    """
    C, H, W = onehot_map.shape
    device = onehot_map.device
    # Integer one-hot inputs are promoted to the default float dtype.
    if onehot_map.is_floating_point():
        dtype = onehot_map.dtype
    else:
        dtype = torch.get_default_dtype()
    onehot = onehot_map.to(dtype)

    # Random main-class probability per cell, in [min_main, max_main).
    main_prob = torch.rand(H, W, device=device, dtype=dtype) * (max_main - min_main) + min_main
    # Keep the off-class mass within [0, epsilon] so the result stays a
    # valid distribution.
    main_prob = main_prob.clamp(min=1.0 - epsilon, max=1.0)

    # Random noise restricted to the non-main classes.
    off_noise = (1 - onehot) * torch.rand(C, H, W, device=device, dtype=dtype)
    # Normalize along the CLASS axis (dim 0 of [C, H, W]); the clamp avoids a
    # division by zero when there is no non-main class (C == 1).
    off_total = off_noise.sum(dim=0, keepdim=True).clamp_min(torch.finfo(dtype).tiny)
    off_noise = off_noise / off_total * (1.0 - main_prob)

    # The main class takes whatever mass the other classes did not receive,
    # which guarantees each cell sums to 1 even when C == 1.
    main_mass = 1.0 - off_noise.sum(dim=0, keepdim=True)
    return onehot * main_mass + off_noise
|
||||
Loading…
Reference in New Issue
Block a user