refactor: UNet 部分重写并改为条件注入模式

This commit is contained in:
unanmed 2025-03-30 13:48:01 +08:00
parent 49ee543732
commit 5669f49af0
8 changed files with 183 additions and 207 deletions

2
gan.sh
View File

@ -2,7 +2,7 @@
python3 -m minamo.train --epochs 10 --resume true
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
python3 -m minamo.train --epochs 10 --resume true
python3 -m ginka.train --epochs 30 --resume true
python3 -m ginka.train --epochs 70 --resume true
python3 -m ginka.validate
# 训练完毕,处理数据
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"

View File

@ -1,41 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualUpsampleBlock(nn.Module):
    """Upsample a feature map 2x, then refine it with two conv stages.

    NOTE(review): despite the name there is no residual/skip connection —
    the block is a plain sequential pipeline of bilinear upsampling
    followed by two Conv-InstanceNorm-GELU stages.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.GELU(),
        ]
        # Keep the attribute name `conv` so state_dict keys stay compatible.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """[B, in_ch, H, W] -> [B, out_ch, 2H, 2W]."""
        return self.conv(x)
class GinkaInput(nn.Module):
    """Project a flat feature vector into a [B, out_ch, 32, 32] feature map.

    A fully-connected layer maps the input to a (out_ch*8) x 4 x 4 seed
    tensor; three upsample blocks then expand 4 -> 8 -> 16 -> 32 spatially
    while reducing the channel count back to out_ch.

    Args:
        feat_dim: dimensionality of the incoming feature vector.
        out_ch: channel count of the produced feature map.
    """

    def __init__(self, feat_dim=1024, out_ch=64):
        super().__init__()
        self.out_ch = out_ch
        hidden = out_ch * 8 * 4 * 4  # flattened size of the 4x4 seed map
        self.fc = nn.Sequential(
            nn.Linear(feat_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
        )
        self.upsample = nn.Sequential(
            ResidualUpsampleBlock(out_ch * 8, out_ch * 8),
            ResidualUpsampleBlock(out_ch * 8, out_ch * 4),
            ResidualUpsampleBlock(out_ch * 4, out_ch),
        )

    def forward(self, x):
        """[B, feat_dim] -> [B, out_ch, 32, 32]."""
        seed = self.fc(x).view(-1, self.out_ch * 8, 4, 4)
        return self.upsample(seed)

View File

@ -117,71 +117,77 @@ def adaptive_count_loss(
class_list: list = list(range(32)),
margin_ratio: float = 0.2,
zero_margin_scale: float = 0.2,
lambda_entropy: float = 0.05,
lambda_local: float = 0.1,
grid_size: int = 8,
eps: float = 1e-3
) -> torch.Tensor:
"""
自适应图块数量约束损失函数
改进版自适应图块数量约束损失包含局部匹配和熵约束
参数:
pred_probs: 预测概率分布 [B, C, H, W]
target_map: 真实地图 [B, C, H, W]
class_list: 需要约束的类别列表
margin_ratio: 允许的相对误差范围如0.2表示±20%
zero_margin_scale: 参考数量为0时的允许余量系数余量=scale*sqrt(H*W))
eps: 数值稳定性常数
返回:
loss: 标量损失值
"""
B, C, H, W = pred_probs.shape
device = pred_probs.device
total_loss = 0.0
valid_classes = 0
# 预计算地图面积用于余量计算
# 预计算地图面积
map_area = math.sqrt(H * W)
# 计算最小非零类别概率
min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean() # 获取预测中的最小非零概率
dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area # 让零类别不被填充
for cls in class_list:
# 预测数量(概率和)
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # [B]
# 真实数量
true_count = target_map[:, cls].sum(dim=(1,2)) # [B]
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测类别数量
true_count = target_map[:, cls].sum(dim=(1,2)) # 真实类别数量
# 动态容差计算
with torch.no_grad():
# 当真实数量为0时的允许上限
zero_mask = (true_count == 0)
dynamic_margin = torch.where(
zero_mask,
zero_margin_scale * map_area, # 允许存在少量
margin_ratio * true_count # 相对误差范围
)
zero_mask = (true_count == 0)
dynamic_margin = torch.where(
zero_mask,
dynamic_zero_margin,
margin_ratio * true_count
)
# 误差计算(考虑数值稳定性)
safe_true = true_count + eps * zero_mask # 零真实值时添加微小量
safe_true = true_count + eps * zero_mask
abs_error = torch.abs(pred_count - true_count)
rel_error = abs_error / safe_true
# 双阶段损失函数
# 阶段一:误差在容差范围内时使用二次函数(强梯度)
# 阶段二:超出容差时转为线性(稳定训练)
loss_per_class = torch.where(
abs_error <= dynamic_margin,
(rel_error ** 2) * 0.5, # 区间内强梯度
rel_error - (0.5 * margin_ratio) # 区间外稳定梯度
(rel_error ** 2) * 0.8 + 0.2 * rel_error,
rel_error - (0.5 * margin_ratio)
)
# 零真实值特殊处理:仅惩罚超出余量部分
loss_per_class = torch.where(
zero_mask,
F.relu(abs_error - dynamic_margin) / map_area, # 归一化处理
F.relu(abs_error - dynamic_margin) / map_area,
loss_per_class
)
total_loss += loss_per_class.mean()
valid_classes += 1
return total_loss / valid_classes # 类别平均
# 平均类别损失
total_loss /= valid_classes
# 加入负熵约束,防止类别均匀化
def entropy_loss(pred_probs):
avg_probs = pred_probs.mean(dim=(2, 3))
entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-6), dim=1)
return entropy.mean()
total_loss += lambda_entropy * entropy_loss(pred_probs)
# 加入局部类别匹配
def local_count_loss(pred_probs, target_probs, grid_size=8):
pred_local = F.avg_pool2d(pred_probs, kernel_size=grid_size, stride=grid_size)
target_local = F.avg_pool2d(target_probs, kernel_size=grid_size, stride=grid_size)
return F.mse_loss(pred_local, target_local)
total_loss += lambda_local * local_count_loss(pred_probs, target_map, grid_size)
return total_loss
def illegal_tile_loss(
pred_probs: torch.Tensor,

View File

@ -2,31 +2,48 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet import GinkaUNet
from .input import GinkaInput
from .output import GinkaOutput
def print_memory(tag=""):
    """Print the current and peak CUDA memory usage (MB), prefixed by *tag*."""
    current = torch.cuda.memory_allocated() / 1024**2
    peak = torch.cuda.max_memory_allocated() / 1024**2
    print(f"{tag} | 当前显存: {current:.2f} MB, 最大显存: {peak:.2f} MB")
class GinkaModel(nn.Module):
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
"""Ginka Model 模型定义部分
"""
super().__init__()
self.input = GinkaInput(feat_dim, base_ch)
self.unet = GinkaUNet(base_ch, num_classes)
self.unet = GinkaUNet(1, base_ch, num_classes, feat_dim)
self.output = GinkaOutput(num_classes, (13, 13))
print(f"Input parameters: {sum(p.numel() for p in self.input.parameters())}")
print(f"UNet parameters: {sum(p.numel() for p in self.unet.parameters())}")
print(f"Output parameters: {sum(p.numel() for p in self.output.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in self.parameters())}")
def forward(self, x):
def forward(self, x, feat):
"""
Args:
feat: 参考地图的特征向量
Returns:
logits: 输出logits [BS, num_classes, H, W]
"""
x = self.input(x)
x = self.unet(x)
x = self.unet(x, feat)
x = self.output(x)
return x, F.softmax(x, dim=1)
# 检查显存占用
if __name__ == "__main__":
x = torch.randn((1, 1, 32, 32)).cuda()
feat = torch.randn((1, 1024)).cuda()
# 初始化模型
model = GinkaModel().cuda()
print_memory("初始化后")
# 前向传播
output, output_softmax = model(x, feat)
print_memory("前向传播后")
print(f"输入形状: x={x.shape}, feat={feat.shape}")
print(f"输出形状: output={output.shape}, softmax={output_softmax.shape}")
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,106 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from shared.attention import CBAM, SEBlock
class GinkaAdaIN(nn.Module):
    """Adaptive Instance Normalization (AdaIN) driven by a condition vector.

    A single linear layer maps the condition vector to per-channel scale
    (gamma) and shift (beta), which modulate the instance-normalized input.

    Args:
        num_features: number of channels to normalize.
        condition_dim: dimensionality of the condition vector.
    """

    def __init__(self, num_features, condition_dim):
        super().__init__()
        # One projection produces both gamma and beta: 2 * num_features.
        self.fc = nn.Linear(condition_dim, num_features * 2)

    def forward(self, x, condition):
        """Normalize then re-style the input.

        Args:
            x: [B, C, H, W] input feature map.
            condition: [B, condition_dim] condition vector to inject.
        Returns:
            [B, C, H, W] tensor: gamma * instance_norm(x) + beta.
        """
        batch, channels = x.shape[0], x.shape[1]
        gamma, beta = self.fc(condition).chunk(2, dim=1)
        normed = F.instance_norm(x)
        return gamma.view(batch, channels, 1, 1) * normed \
            + beta.view(batch, channels, 1, 1)
class ConvBlock(nn.Module):
    """Two stacked Conv3x3 -> BatchNorm -> ReLU stages; spatial size is kept."""

    def __init__(self, in_ch, out_ch):
        super().__init__()

        def stage(cin, cout):
            # One Conv-BN-ReLU unit; 3x3 kernel with padding=1 keeps H and W.
            return [
                nn.Conv2d(cin, cout, 3, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(),
            ]

        # Same module order as before so state_dict keys are unchanged.
        self.conv = nn.Sequential(*stage(in_ch, out_ch), *stage(out_ch, out_ch))

    def forward(self, x):
        """[B, in_ch, H, W] -> [B, out_ch, H, W]."""
        return self.conv(x)
class AdaINConvBlock(nn.Module):
    """ConvBlock followed by condition-injecting AdaIN.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        feat_dim: dimensionality of the condition vector fed to AdaIN.
    """

    def __init__(self, in_ch, out_ch, feat_dim):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.adain = GinkaAdaIN(out_ch, feat_dim)

    def forward(self, x, feat):
        """Convolve x [B, in_ch, H, W], then modulate by feat [B, feat_dim]."""
        return self.adain(self.conv(x), feat)
class GinkaEncoder(nn.Module):
"""编码器(下采样)部分"""
def __init__(self, in_channels, out_channels, attention=False, block='CBAM'):
def __init__(self, in_ch, out_ch, feat_dim):
super().__init__()
self.conv = ConvBlock(in_ch, out_ch)
self.pool = nn.MaxPool2d(2)
self.adain = GinkaAdaIN(out_ch, feat_dim)
def forward(self, x, feat):
x = self.conv(x)
x = self.pool(x)
x = self.adain(x, feat)
return x
class GinkaUpSample(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.InstanceNorm2d(out_channels),
nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
nn.BatchNorm2d(out_ch),
nn.GELU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.InstanceNorm2d(out_channels),
)
# 注意力
if attention:
if block == 'CBAM':
self.conv.append(CBAM(out_channels))
elif block == 'SEBlock':
self.conv.append(SEBlock(out_channels))
self.conv.append(nn.GELU())
self.down = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
def forward(self, x):
x_res = self.conv(x)
x_down = self.down(x_res)
return x_down, x_res
return self.conv(x)
class GinkaDecoder(nn.Module):
"""解码器(上采样)部分"""
def __init__(self, in_channels, out_channels, attention=False, block='CBAM'):
def __init__(self, in_ch, out_ch, feat_dim):
super().__init__()
self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
self.conv = nn.Sequential(
nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1),
nn.InstanceNorm2d(out_channels),
)
# 注意力
if attention:
if block == 'CBAM':
self.conv.append(CBAM(out_channels))
elif block == 'SEBlock':
self.conv.append(SEBlock(out_channels))
self.conv.append(nn.GELU())
def forward(self, x, skip):
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
self.conv = ConvBlock(in_ch, out_ch)
self.adain = GinkaAdaIN(out_ch, feat_dim)
def forward(self, x, skip, feat):
x = self.upsample(x)
x = torch.cat([x, skip], dim=1)
x = torch.cat([x, skip], dim=1)
x = self.conv(x)
x = self.adain(x, feat)
return x
class GinkaBottleneck(nn.Module):
    """UNet bottleneck: two Conv-InstanceNorm stages with optional attention.

    Layer order: conv, norm, GELU, conv, norm, [SEBlock], GELU — the
    optional squeeze-and-excitation block sits before the final GELU.
    """

    def __init__(self, in_channels, out_channels, attention=False):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
        ]
        if attention:
            layers.append(SEBlock(out_channels))
        layers.append(nn.GELU())
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """[B, in_channels, H, W] -> [B, out_channels, H, W]."""
        return self.conv(x)
class GinkaUNet(nn.Module):
def __init__(self, in_ch=64, out_ch=32):
def __init__(self, in_ch=1, base_ch=64, out_ch=32, feat_dim=1024):
"""Ginka Model UNet 部分
"""
super().__init__()
self.down1 = GinkaEncoder(in_ch, in_ch*2, attention=True)
self.down2 = GinkaEncoder(in_ch*2, in_ch*4, attention=True)
self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock')
self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock')
self.in_conv = AdaINConvBlock(in_ch, base_ch, feat_dim)
self.down1 = GinkaEncoder(base_ch, base_ch*2, feat_dim)
self.down2 = GinkaEncoder(base_ch*2, base_ch*4, feat_dim)
self.down3 = GinkaEncoder(base_ch*4, base_ch*8, feat_dim)
self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16, attention=True)
self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16, feat_dim)
self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock')
self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock')
self.up3 = GinkaDecoder(in_ch*4, in_ch*2, attention=True)
self.up4 = GinkaDecoder(in_ch*2, in_ch, attention=True)
self.up1 = GinkaDecoder(base_ch*16, base_ch*8, feat_dim)
self.up2 = GinkaDecoder(base_ch*8, base_ch*4, feat_dim)
self.up3 = GinkaDecoder(base_ch*4, base_ch*2, feat_dim)
self.up4 = GinkaDecoder(base_ch*2, base_ch, feat_dim)
self.final = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 1),
nn.Conv2d(base_ch, out_ch, 1),
)
def forward(self, x):
x_down1, skip1 = self.down1(x)
x_down2, skip2 = self.down2(x_down1)
x_down3, skip3 = self.down3(x_down2)
x_down4, skip4 = self.down4(x_down3)
x = self.bottleneck(x_down4)
x = self.up1(x, skip4)
x = self.up2(x, skip3)
x = self.up3(x, skip2)
x = self.up4(x, skip1)
def forward(self, x, feat):
x1 = self.in_conv(x, feat)
x2 = self.down1(x1, feat)
x3 = self.down2(x2, feat)
x4 = self.down3(x3, feat)
x5 = self.bottleneck(x4, feat)
x = self.up1(x5, x4, feat)
x = self.up2(x, x3, feat)
x = self.up3(x, x2, feat)
x = self.up4(x, x1, feat)
return self.final(x)

View File

@ -10,14 +10,12 @@ from .dataset import GinkaDataset
from minamo.model.model import MinamoModel
from shared.args import parse_arguments
BATCH_SIZE = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True)
# 在生成器输出后添加梯度检查钩子
def grad_hook(module, grad_input, grad_output):
print(f"Generator output grad norm: {grad_output[0].norm().item()}")
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
@ -29,21 +27,18 @@ def train():
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
minamo.to(device)
minamo.eval()
# for param in minamo.parameters():
# param.requires_grad = False
# 准备数据集
dataset = GinkaDataset(args.train, device, minamo)
dataset_val = GinkaDataset(args.validate, device, minamo)
dataloader = DataLoader(
dataset,
batch_size=32,
batch_size=BATCH_SIZE,
shuffle=True
)
dataloader_val = DataLoader(
dataset_val,
batch_size=32,
batch_size=BATCH_SIZE,
shuffle=True
)
@ -52,9 +47,6 @@ def train():
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss(minamo)
# model.register_full_backward_hook(grad_hook)
# converter.register_full_backward_hook(grad_hook)
# criterion.register_full_backward_hook(grad_hook)
if args.resume:
data = torch.load(args.from_state, map_location=device)
model.load_state_dict(data["model_state"])
@ -83,30 +75,20 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
# 前向传播
optimizer.zero_grad()
_, output_softmax = model(feat_vec)
noise = torch.randn((BATCH_SIZE, 1, 32, 32))
_, output_softmax = model(noise, feat_vec)
# 计算损失
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
# 反向传播
scaled_losses.backward()
losses.backward()
optimizer.step()
total_loss += losses.item()
# for name, param in model.named_parameters():
# if param.grad is not None:
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
avg_loss = total_loss / len(dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
# total_norm = 0
# for p in model.parameters():
# if p.grad is not None:
# param_norm = p.grad.detach().data.norm(2)
# total_norm += param_norm.item() ** 2
# total_norm = total_norm ** 0.5
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
# for name, param in model.named_parameters():
# if param.grad is not None:
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")

View File

@ -1,4 +1,5 @@
import json
import random
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
@ -28,8 +29,12 @@ class MinamoDataset(Dataset):
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
map1_probs = random_smooth_onehot(map1_probs)
map2_probs = random_smooth_onehot(map2_probs)
min_main = random.uniform(0.7, 1)
max_main = random.uniform(0.9, 1)
epsilon = random.uniform(0, 0.3)
map1_probs = random_smooth_onehot(map1_probs, min_main, max_main, epsilon)
map2_probs = random_smooth_onehot(map2_probs, min_main, max_main, epsilon)
graph1 = differentiable_convert_to_data(map1_probs)
graph2 = differentiable_convert_to_data(map2_probs)

View File

@ -16,24 +16,6 @@ os.makedirs("result", exist_ok=True)
os.makedirs("result/minamo_checkpoint", exist_ok=True)
disable_tqdm = not sys.stdout.isatty() # 如果 stdout 被重定向,则禁用 tqdm
def collate_fn(batch):
    """Collate (map1, map2, vis_sim, topo_sim) samples into batched tensors.

    Returns:
        tuple of (map1 [B, H, W], map2 [B, H, W], vis_sim, topo_sim), where
        the similarity tensors are the per-sample tensors concatenated.
    """
    maps1, maps2, vis_parts, topo_parts = zip(*batch)
    # All maps in a batch must share a spatial size, otherwise stack fails.
    first_shape = maps1[0].shape
    assert all(m.shape == first_shape for m in maps1), \
        "对比地图必须尺寸相同"
    return (
        torch.stack(list(maps1)),   # (B, H, W)
        torch.stack(list(maps2)),   # (B, H, W)
        torch.cat(list(vis_parts)),
        torch.cat(list(topo_parts)),
    )
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
@ -91,6 +73,9 @@ def train():
graph1 = graph1.to(device)
graph2 = graph2.to(device)
if map1.shape[0] == 1:
continue
# 前向传播
optimizer.zero_grad()
vision_feat1, topo_feat1 = model(map1, graph1)