mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
refactor: UNet 部分重写并改为条件注入模式
This commit is contained in:
parent
49ee543732
commit
5669f49af0
2
gan.sh
2
gan.sh
@ -2,7 +2,7 @@
|
||||
python3 -m minamo.train --epochs 10 --resume true
|
||||
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
|
||||
python3 -m minamo.train --epochs 10 --resume true
|
||||
python3 -m ginka.train --epochs 30 --resume true
|
||||
python3 -m ginka.train --epochs 70 --resume true
|
||||
python3 -m ginka.validate
|
||||
# 训练完毕,处理数据
|
||||
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
||||
|
||||
@ -1,41 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class ResidualUpsampleBlock(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
|
||||
nn.Conv2d(in_ch, out_ch, 3, padding=1),
|
||||
nn.InstanceNorm2d(out_ch),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
||||
nn.InstanceNorm2d(out_ch),
|
||||
nn.GELU()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class GinkaInput(nn.Module):
|
||||
def __init__(self, feat_dim=1024, out_ch=64):
|
||||
super().__init__()
|
||||
fc_dim = out_ch * 8 * 4 * 4
|
||||
self.out_ch = out_ch
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(feat_dim, fc_dim),
|
||||
nn.BatchNorm1d(fc_dim),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.upsample = nn.Sequential(
|
||||
ResidualUpsampleBlock(out_ch*8, out_ch*8),
|
||||
ResidualUpsampleBlock(out_ch*8, out_ch*4),
|
||||
ResidualUpsampleBlock(out_ch*4, out_ch)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc(x)
|
||||
x = x.view(-1, self.out_ch*8, 4, 4)
|
||||
x = self.upsample(x)
|
||||
return x
|
||||
@ -117,71 +117,77 @@ def adaptive_count_loss(
|
||||
class_list: list = list(range(32)),
|
||||
margin_ratio: float = 0.2,
|
||||
zero_margin_scale: float = 0.2,
|
||||
lambda_entropy: float = 0.05,
|
||||
lambda_local: float = 0.1,
|
||||
grid_size: int = 8,
|
||||
eps: float = 1e-3
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
自适应图块数量约束损失函数
|
||||
改进版自适应图块数量约束损失,包含局部匹配和熵约束
|
||||
|
||||
参数:
|
||||
pred_probs: 预测概率分布 [B, C, H, W]
|
||||
target_map: 真实地图 [B, C, H, W]
|
||||
class_list: 需要约束的类别列表
|
||||
margin_ratio: 允许的相对误差范围(如0.2表示±20%)
|
||||
zero_margin_scale: 参考数量为0时的允许余量系数(余量=scale*sqrt(H*W))
|
||||
eps: 数值稳定性常数
|
||||
|
||||
返回:
|
||||
loss: 标量损失值
|
||||
"""
|
||||
B, C, H, W = pred_probs.shape
|
||||
device = pred_probs.device
|
||||
total_loss = 0.0
|
||||
valid_classes = 0
|
||||
|
||||
# 预计算地图面积用于余量计算
|
||||
# 预计算地图面积
|
||||
map_area = math.sqrt(H * W)
|
||||
|
||||
# 计算最小非零类别概率
|
||||
min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean() # 获取预测中的最小非零概率
|
||||
dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area # 让零类别不被填充
|
||||
|
||||
for cls in class_list:
|
||||
# 预测数量(概率和)
|
||||
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # [B]
|
||||
# 真实数量
|
||||
true_count = target_map[:, cls].sum(dim=(1,2)) # [B]
|
||||
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测类别数量
|
||||
true_count = target_map[:, cls].sum(dim=(1,2)) # 真实类别数量
|
||||
|
||||
# 动态容差计算
|
||||
with torch.no_grad():
|
||||
# 当真实数量为0时的允许上限
|
||||
zero_mask = (true_count == 0)
|
||||
dynamic_margin = torch.where(
|
||||
zero_mask,
|
||||
zero_margin_scale * map_area, # 允许存在少量
|
||||
margin_ratio * true_count # 相对误差范围
|
||||
)
|
||||
zero_mask = (true_count == 0)
|
||||
dynamic_margin = torch.where(
|
||||
zero_mask,
|
||||
dynamic_zero_margin,
|
||||
margin_ratio * true_count
|
||||
)
|
||||
|
||||
# 误差计算(考虑数值稳定性)
|
||||
safe_true = true_count + eps * zero_mask # 零真实值时添加微小量
|
||||
safe_true = true_count + eps * zero_mask
|
||||
abs_error = torch.abs(pred_count - true_count)
|
||||
rel_error = abs_error / safe_true
|
||||
|
||||
# 双阶段损失函数
|
||||
# 阶段一:误差在容差范围内时使用二次函数(强梯度)
|
||||
# 阶段二:超出容差时转为线性(稳定训练)
|
||||
loss_per_class = torch.where(
|
||||
abs_error <= dynamic_margin,
|
||||
(rel_error ** 2) * 0.5, # 区间内强梯度
|
||||
rel_error - (0.5 * margin_ratio) # 区间外稳定梯度
|
||||
(rel_error ** 2) * 0.8 + 0.2 * rel_error,
|
||||
rel_error - (0.5 * margin_ratio)
|
||||
)
|
||||
|
||||
# 零真实值特殊处理:仅惩罚超出余量部分
|
||||
loss_per_class = torch.where(
|
||||
zero_mask,
|
||||
F.relu(abs_error - dynamic_margin) / map_area, # 归一化处理
|
||||
F.relu(abs_error - dynamic_margin) / map_area,
|
||||
loss_per_class
|
||||
)
|
||||
|
||||
total_loss += loss_per_class.mean()
|
||||
valid_classes += 1
|
||||
|
||||
return total_loss / valid_classes # 类别平均
|
||||
# 平均类别损失
|
||||
total_loss /= valid_classes
|
||||
|
||||
# 加入负熵约束,防止类别均匀化
|
||||
def entropy_loss(pred_probs):
|
||||
avg_probs = pred_probs.mean(dim=(2, 3))
|
||||
entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-6), dim=1)
|
||||
return entropy.mean()
|
||||
|
||||
total_loss += lambda_entropy * entropy_loss(pred_probs)
|
||||
|
||||
# 加入局部类别匹配
|
||||
def local_count_loss(pred_probs, target_probs, grid_size=8):
|
||||
pred_local = F.avg_pool2d(pred_probs, kernel_size=grid_size, stride=grid_size)
|
||||
target_local = F.avg_pool2d(target_probs, kernel_size=grid_size, stride=grid_size)
|
||||
return F.mse_loss(pred_local, target_local)
|
||||
|
||||
total_loss += lambda_local * local_count_loss(pred_probs, target_map, grid_size)
|
||||
|
||||
return total_loss
|
||||
|
||||
def illegal_tile_loss(
|
||||
pred_probs: torch.Tensor,
|
||||
|
||||
@ -2,31 +2,48 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .unet import GinkaUNet
|
||||
from .input import GinkaInput
|
||||
from .output import GinkaOutput
|
||||
|
||||
def print_memory(tag=""):
|
||||
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
||||
|
||||
class GinkaModel(nn.Module):
|
||||
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
|
||||
"""Ginka Model 模型定义部分
|
||||
"""
|
||||
super().__init__()
|
||||
self.input = GinkaInput(feat_dim, base_ch)
|
||||
self.unet = GinkaUNet(base_ch, num_classes)
|
||||
self.unet = GinkaUNet(1, base_ch, num_classes, feat_dim)
|
||||
self.output = GinkaOutput(num_classes, (13, 13))
|
||||
print(f"Input parameters: {sum(p.numel() for p in self.input.parameters())}")
|
||||
print(f"UNet parameters: {sum(p.numel() for p in self.unet.parameters())}")
|
||||
print(f"Output parameters: {sum(p.numel() for p in self.output.parameters())}")
|
||||
print(f"Total parameters: {sum(p.numel() for p in self.parameters())}")
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, feat):
|
||||
"""
|
||||
Args:
|
||||
feat: 参考地图的特征向量
|
||||
Returns:
|
||||
logits: 输出logits [BS, num_classes, H, W]
|
||||
"""
|
||||
x = self.input(x)
|
||||
x = self.unet(x)
|
||||
x = self.unet(x, feat)
|
||||
x = self.output(x)
|
||||
return x, F.softmax(x, dim=1)
|
||||
|
||||
# 检查显存占用
|
||||
if __name__ == "__main__":
|
||||
x = torch.randn((1, 1, 32, 32)).cuda()
|
||||
feat = torch.randn((1, 1024)).cuda()
|
||||
|
||||
# 初始化模型
|
||||
model = GinkaModel().cuda()
|
||||
|
||||
print_memory("初始化后")
|
||||
|
||||
# 前向传播
|
||||
output, output_softmax = model(x, feat)
|
||||
|
||||
print_memory("前向传播后")
|
||||
|
||||
print(f"输入形状: x={x.shape}, feat={feat.shape}")
|
||||
print(f"输出形状: output={output.shape}, softmax={output_softmax.shape}")
|
||||
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
|
||||
print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
|
||||
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
@ -1,106 +1,128 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from shared.attention import CBAM, SEBlock
|
||||
|
||||
class GinkaAdaIN(nn.Module):
|
||||
def __init__(self, num_features, condition_dim):
|
||||
"""
|
||||
自适应实例归一化 (AdaIN)
|
||||
参数:
|
||||
num_features: 归一化的通道数
|
||||
condition_dim: 条件输入的特征维度
|
||||
"""
|
||||
super(GinkaAdaIN, self).__init__()
|
||||
self.fc = nn.Linear(condition_dim, num_features * 2) # γ 和 β
|
||||
|
||||
def forward(self, x, condition):
|
||||
"""
|
||||
x: [B, C, H, W] - 输入特征图
|
||||
condition: [B, condition_dim] - 需要注入的条件向量
|
||||
"""
|
||||
gamma, beta = self.fc(condition).chunk(2, dim=1) # 分割为 γ 和 β
|
||||
gamma = gamma.view(x.shape[0], x.shape[1], 1, 1) # 调整形状
|
||||
beta = beta.view(x.shape[0], x.shape[1], 1, 1)
|
||||
|
||||
x = F.instance_norm(x) # 标准化
|
||||
return gamma * x + beta # 进行变换
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, 3, padding=1),
|
||||
nn.BatchNorm2d(out_ch),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
||||
nn.BatchNorm2d(out_ch),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class AdaINConvBlock(nn.Module):
|
||||
def __init__(self, in_ch, out_ch, feat_dim):
|
||||
super().__init__()
|
||||
self.conv = ConvBlock(in_ch, out_ch)
|
||||
self.adain = GinkaAdaIN(out_ch, feat_dim)
|
||||
|
||||
def forward(self, x, feat):
|
||||
x = self.conv(x)
|
||||
x = self.adain(x, feat)
|
||||
return x
|
||||
|
||||
class GinkaEncoder(nn.Module):
|
||||
"""编码器(下采样)部分"""
|
||||
def __init__(self, in_channels, out_channels, attention=False, block='CBAM'):
|
||||
def __init__(self, in_ch, out_ch, feat_dim):
|
||||
super().__init__()
|
||||
self.conv = ConvBlock(in_ch, out_ch)
|
||||
self.pool = nn.MaxPool2d(2)
|
||||
self.adain = GinkaAdaIN(out_ch, feat_dim)
|
||||
|
||||
def forward(self, x, feat):
|
||||
x = self.conv(x)
|
||||
x = self.pool(x)
|
||||
x = self.adain(x, feat)
|
||||
return x
|
||||
|
||||
class GinkaUpSample(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
|
||||
nn.BatchNorm2d(out_ch),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
)
|
||||
# 注意力
|
||||
if attention:
|
||||
if block == 'CBAM':
|
||||
self.conv.append(CBAM(out_channels))
|
||||
elif block == 'SEBlock':
|
||||
self.conv.append(SEBlock(out_channels))
|
||||
self.conv.append(nn.GELU())
|
||||
self.down = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x_res = self.conv(x)
|
||||
x_down = self.down(x_res)
|
||||
return x_down, x_res
|
||||
return self.conv(x)
|
||||
|
||||
class GinkaDecoder(nn.Module):
|
||||
"""解码器(上采样)部分"""
|
||||
def __init__(self, in_channels, out_channels, attention=False, block='CBAM'):
|
||||
def __init__(self, in_ch, out_ch, feat_dim):
|
||||
super().__init__()
|
||||
self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
)
|
||||
# 注意力
|
||||
if attention:
|
||||
if block == 'CBAM':
|
||||
self.conv.append(CBAM(out_channels))
|
||||
elif block == 'SEBlock':
|
||||
self.conv.append(SEBlock(out_channels))
|
||||
self.conv.append(nn.GELU())
|
||||
|
||||
def forward(self, x, skip):
|
||||
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
|
||||
self.conv = ConvBlock(in_ch, out_ch)
|
||||
self.adain = GinkaAdaIN(out_ch, feat_dim)
|
||||
|
||||
def forward(self, x, skip, feat):
|
||||
x = self.upsample(x)
|
||||
x = torch.cat([x, skip], dim=1)
|
||||
x = torch.cat([x, skip], dim=1)
|
||||
x = self.conv(x)
|
||||
x = self.adain(x, feat)
|
||||
return x
|
||||
|
||||
class GinkaBottleneck(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, attention=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
)
|
||||
if attention:
|
||||
self.conv.append(SEBlock(out_channels))
|
||||
self.conv.append(nn.GELU())
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class GinkaUNet(nn.Module):
|
||||
def __init__(self, in_ch=64, out_ch=32):
|
||||
def __init__(self, in_ch=1, base_ch=64, out_ch=32, feat_dim=1024):
|
||||
"""Ginka Model UNet 部分
|
||||
"""
|
||||
super().__init__()
|
||||
self.down1 = GinkaEncoder(in_ch, in_ch*2, attention=True)
|
||||
self.down2 = GinkaEncoder(in_ch*2, in_ch*4, attention=True)
|
||||
self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock')
|
||||
self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock')
|
||||
self.in_conv = AdaINConvBlock(in_ch, base_ch, feat_dim)
|
||||
self.down1 = GinkaEncoder(base_ch, base_ch*2, feat_dim)
|
||||
self.down2 = GinkaEncoder(base_ch*2, base_ch*4, feat_dim)
|
||||
self.down3 = GinkaEncoder(base_ch*4, base_ch*8, feat_dim)
|
||||
|
||||
self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16, attention=True)
|
||||
self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16, feat_dim)
|
||||
|
||||
self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock')
|
||||
self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock')
|
||||
self.up3 = GinkaDecoder(in_ch*4, in_ch*2, attention=True)
|
||||
self.up4 = GinkaDecoder(in_ch*2, in_ch, attention=True)
|
||||
self.up1 = GinkaDecoder(base_ch*16, base_ch*8, feat_dim)
|
||||
self.up2 = GinkaDecoder(base_ch*8, base_ch*4, feat_dim)
|
||||
self.up3 = GinkaDecoder(base_ch*4, base_ch*2, feat_dim)
|
||||
self.up4 = GinkaDecoder(base_ch*2, base_ch, feat_dim)
|
||||
|
||||
self.final = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, 1),
|
||||
nn.Conv2d(base_ch, out_ch, 1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x_down1, skip1 = self.down1(x)
|
||||
x_down2, skip2 = self.down2(x_down1)
|
||||
x_down3, skip3 = self.down3(x_down2)
|
||||
x_down4, skip4 = self.down4(x_down3)
|
||||
|
||||
x = self.bottleneck(x_down4)
|
||||
|
||||
x = self.up1(x, skip4)
|
||||
x = self.up2(x, skip3)
|
||||
x = self.up3(x, skip2)
|
||||
x = self.up4(x, skip1)
|
||||
def forward(self, x, feat):
|
||||
x1 = self.in_conv(x, feat)
|
||||
x2 = self.down1(x1, feat)
|
||||
x3 = self.down2(x2, feat)
|
||||
x4 = self.down3(x3, feat)
|
||||
x5 = self.bottleneck(x4, feat)
|
||||
|
||||
x = self.up1(x5, x4, feat)
|
||||
x = self.up2(x, x3, feat)
|
||||
x = self.up3(x, x2, feat)
|
||||
x = self.up4(x, x1, feat)
|
||||
|
||||
return self.final(x)
|
||||
|
||||
@ -10,14 +10,12 @@ from .dataset import GinkaDataset
|
||||
from minamo.model.model import MinamoModel
|
||||
from shared.args import parse_arguments
|
||||
|
||||
BATCH_SIZE = 32
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
||||
|
||||
# 在生成器输出后添加梯度检查钩子
|
||||
def grad_hook(module, grad_input, grad_output):
|
||||
print(f"Generator output grad norm: {grad_output[0].norm().item()}")
|
||||
|
||||
def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
|
||||
@ -29,21 +27,18 @@ def train():
|
||||
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
||||
minamo.to(device)
|
||||
minamo.eval()
|
||||
|
||||
# for param in minamo.parameters():
|
||||
# param.requires_grad = False
|
||||
|
||||
# 准备数据集
|
||||
dataset = GinkaDataset(args.train, device, minamo)
|
||||
dataset_val = GinkaDataset(args.validate, device, minamo)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True
|
||||
)
|
||||
dataloader_val = DataLoader(
|
||||
dataset_val,
|
||||
batch_size=32,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
@ -52,9 +47,6 @@ def train():
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion = GinkaLoss(minamo)
|
||||
|
||||
# model.register_full_backward_hook(grad_hook)
|
||||
# converter.register_full_backward_hook(grad_hook)
|
||||
# criterion.register_full_backward_hook(grad_hook)
|
||||
if args.resume:
|
||||
data = torch.load(args.from_state, map_location=device)
|
||||
model.load_state_dict(data["model_state"])
|
||||
@ -83,30 +75,20 @@ def train():
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
_, output_softmax = model(feat_vec)
|
||||
noise = torch.randn((BATCH_SIZE, 1, 32, 32))
|
||||
_, output_softmax = model(noise, feat_vec)
|
||||
|
||||
# 计算损失
|
||||
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
|
||||
# 反向传播
|
||||
scaled_losses.backward()
|
||||
losses.backward()
|
||||
optimizer.step()
|
||||
total_loss += losses.item()
|
||||
# for name, param in model.named_parameters():
|
||||
# if param.grad is not None:
|
||||
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
|
||||
# total_norm = 0
|
||||
# for p in model.parameters():
|
||||
# if p.grad is not None:
|
||||
# param_norm = p.grad.detach().data.norm(2)
|
||||
# total_norm += param_norm.item() ** 2
|
||||
# total_norm = total_norm ** 0.5
|
||||
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
|
||||
|
||||
# for name, param in model.named_parameters():
|
||||
# if param.grad is not None:
|
||||
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import random
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
@ -28,8 +29,12 @@ class MinamoDataset(Dataset):
|
||||
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
|
||||
map1_probs = random_smooth_onehot(map1_probs)
|
||||
map2_probs = random_smooth_onehot(map2_probs)
|
||||
min_main = random.uniform(0.7, 1)
|
||||
max_main = random.uniform(0.9, 1)
|
||||
epsilon = random.uniform(0, 0.3)
|
||||
|
||||
map1_probs = random_smooth_onehot(map1_probs, min_main, max_main, epsilon)
|
||||
map2_probs = random_smooth_onehot(map2_probs, min_main, max_main, epsilon)
|
||||
|
||||
graph1 = differentiable_convert_to_data(map1_probs)
|
||||
graph2 = differentiable_convert_to_data(map2_probs)
|
||||
|
||||
@ -16,24 +16,6 @@ os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/minamo_checkpoint", exist_ok=True)
|
||||
disable_tqdm = not sys.stdout.isatty() # 如果 stdout 被重定向,则禁用 tqdm
|
||||
|
||||
def collate_fn(batch):
|
||||
"""动态处理不同尺寸地图的批处理"""
|
||||
map1_batch = [item[0] for item in batch]
|
||||
map2_batch = [item[1] for item in batch]
|
||||
vis_sim = torch.cat([item[2] for item in batch])
|
||||
topo_sim = torch.cat([item[3] for item in batch])
|
||||
|
||||
# 保持批次内地图尺寸一致(根据问题描述)
|
||||
assert all(m.shape == map1_batch[0].shape for m in map1_batch), \
|
||||
"对比地图必须尺寸相同"
|
||||
|
||||
return (
|
||||
torch.stack(map1_batch), # (B, H, W)
|
||||
torch.stack(map2_batch), # (B, H, W)
|
||||
vis_sim,
|
||||
topo_sim
|
||||
)
|
||||
|
||||
def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
|
||||
@ -91,6 +73,9 @@ def train():
|
||||
graph1 = graph1.to(device)
|
||||
graph2 = graph2.to(device)
|
||||
|
||||
if map1.shape[0] == 1:
|
||||
continue
|
||||
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
vision_feat1, topo_feat1 = model(map1, graph1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user