feat: 改进生成器网络

This commit is contained in:
unanmed 2025-04-01 12:53:52 +08:00
parent 96f828e29b
commit 4721e9a141
7 changed files with 106 additions and 185 deletions

44
ginka/model/input.py Normal file
View File

@ -0,0 +1,44 @@
import torch
import torch.nn as nn
class GinkaInput(nn.Module):
    """Project a feature vector into a small stack of 2D maps.

    A linear layer expands the input vector to ``out_ch * H * W`` values,
    which are then reshaped to ``[B, out_ch, H, W]``.
    """

    def __init__(self, feat_dim=1024, out_ch=1, size=(32, 32)):
        super().__init__()
        height, width = size
        self.fc = nn.Sequential(
            nn.Linear(feat_dim, out_ch * height * width),
            nn.Unflatten(1, (out_ch, height, width)),
        )

    def forward(self, x):
        """x: [B, feat_dim] -> [B, out_ch, H, W]."""
        return self.fc(x)
class FeatureEncoder(nn.Module):
    """Turn a feature vector into a ``size`` x ``size`` feature map.

    Pipeline: Linear -> reshape to [B, mid_ch, size, size] -> 1x1 conv
    mapping ``mid_ch`` channels to ``out_ch`` channels.
    """

    def __init__(self, feat_dim, size, mid_ch, out_ch):
        super().__init__()
        stages = [
            nn.Linear(feat_dim, mid_ch * size * size),
            nn.Unflatten(1, (mid_ch, size, size)),
            nn.Conv2d(mid_ch, out_ch, kernel_size=1),
        ]
        self.encode = nn.Sequential(*stages)

    def forward(self, x):
        """x: [B, feat_dim] -> [B, out_ch, size, size]."""
        return self.encode(x)
class GinkaFeatureInput(nn.Module):
    """Encode one feature vector into a five-level feature pyramid.

    Spatial resolutions are 32/16/8/4/2; ``mid_ch`` and ``out_ch`` are
    scaled by 1, 2, 4, 8, 16 at the successive levels.
    """

    def __init__(self, feat_dim=1024, mid_ch=1, out_ch=64):
        super().__init__()
        # (side length, channel multiplier) per pyramid level.
        specs = ((32, 1), (16, 2), (8, 4), (4, 8), (2, 16))
        self.encode1, self.encode2, self.encode3, self.encode4, self.encode5 = (
            FeatureEncoder(feat_dim, side, mid_ch * k, out_ch * k)
            for side, k in specs
        )

    def forward(self, x):
        """x: [B, feat_dim] -> five maps from 32x32 down to 2x2."""
        encoders = (self.encode1, self.encode2, self.encode3,
                    self.encode4, self.encode5)
        return tuple(enc(x) for enc in encoders)

View File

@ -86,7 +86,7 @@ def _create_distance_kernel(size):
dist = torch.sqrt((x - center)**2 + (y - center)**2)
kernel = 1 / (dist + 1)
kernel /= kernel.sum() # 归一化
return kernel.unsqueeze(0).unsqueeze(0), 1 / kernel.sum() # [1,1,H,W]
return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W]
def entrance_constraint_loss(
pred: torch.Tensor,
@ -123,15 +123,13 @@ def entrance_constraint_loss(
center_weight = center_weight.clamp(0,1).to(pred.device) # [H,W]
# 概率密度感知的间距计算
kernel, cw = _create_distance_kernel(min_distance) # 自定义函数生成权重核
kernel = _create_distance_kernel(min_distance) # 自定义函数生成权重核
kernel = kernel.to(pred.device)
density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1)
spacing_loss = density_map.mean()
###########################
# 区域加权综合损失
###########################
total_loss = (
lambda_presence * presence_loss +
lambda_spacing * (spacing_loss * center_weight).mean()
@ -225,116 +223,6 @@ def adaptive_count_loss(
return total_loss
def illegal_tile_loss(
    pred_probs: torch.Tensor,
    legal_classes: int = 13,
    temperature: float = 0.1,
    eps: float = 1e-8
) -> torch.Tensor:
    """Penalize probability mass assigned to illegal tile classes.

    Args:
        pred_probs: Predicted class probabilities [B, C, H, W].
        legal_classes: Number of legal classes (0-based); channels at
            index ``legal_classes`` and above are treated as illegal.
        temperature: Unused; kept for backward compatibility. It only
            fed a probability-sharpening step whose result was never
            used, which has been removed as dead code.
        eps: Unused; kept for backward compatibility (same reason).

    Returns:
        Scalar loss tensor.

    NOTE(review): with exactly one illegal channel, ``torch.var`` over
    a single element returns NaN (unbiased variance), making the loss
    NaN — confirm callers always provide >= 2 illegal channels.
    """
    # Probabilities of the illegal channels: [B, C_illegal, H, W].
    illegal_probs = pred_probs[:, legal_classes:, :, :]

    # Spatial attention: pixels whose most-confident illegal class
    # exceeds 0.5 get up to double weight. Detached so the weighting
    # itself receives no gradient.
    with torch.no_grad():
        confidence = illegal_probs.max(dim=1)[0]  # [B, H, W]
        spatial_weights = 1 + torch.sigmoid(10 * (confidence - 0.5))

    # log(1 + total illegal probability) per pixel, then weighted mean.
    per_pixel_loss = torch.log(1 + illegal_probs.sum(dim=1))  # [B, H, W]
    weighted_loss = (per_pixel_loss * spatial_weights).mean()

    # Class-balance factor (scalar): grows when illegal mass is spread
    # unevenly across the illegal classes.
    class_balance = 1 + torch.var(illegal_probs.mean(dim=(0, 2, 3)))

    return weighted_loss * class_balance.mean()
def entrance_spatial_constraint(
    pred_probs: torch.Tensor,
    arrow_class: int = 11,
    stair_class: int = 10,
    border_width: int = 1,
    lambda_arrow: float = 1.0,
    lambda_stair: float = 1.0
) -> torch.Tensor:
    """Spatial constraint loss for entrance placement.

    Edge cells may only host arrow entrances and interior cells may only
    host stair entrances; probability above a 10% allowance in the wrong
    region is penalized.

    Args:
        pred_probs: predicted class probabilities [B, C, H, W].
        arrow_class: channel index of the arrow entrance class.
        stair_class: channel index of the stair entrance class.
        border_width: thickness of the edge ring (1 = outermost cells).
        lambda_arrow: weight of the edge (no-stairs) penalty term.
        lambda_stair: weight of the interior (no-arrows) penalty term.

    Returns:
        Scalar loss tensor.
    """
    _, _, height, width = pred_probs.shape

    # Boolean ring mask: True on the outer border_width cells, built by
    # clearing the interior rectangle of an all-True grid.
    edge_mask = torch.ones((height, width), dtype=torch.bool,
                           device=pred_probs.device)
    edge_mask[border_width:height - border_width,
              border_width:width - border_width] = False
    center_mask = ~edge_mask

    # Stairs must not appear on the edge (10% slack before penalizing).
    stair_on_edge = pred_probs[:, stair_class][:, edge_mask]
    edge_penalty = (stair_on_edge - 0.1).clamp_min(0).mean()

    # Arrows must not appear in the interior (same 10% slack).
    arrow_in_center = pred_probs[:, arrow_class][:, center_mask]
    center_penalty = (arrow_in_center - 0.1).clamp_min(0).mean()

    return lambda_arrow * edge_penalty + lambda_stair * center_penalty
class GinkaLoss(nn.Module):
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
"""Ginka Model 损失函数部分
@ -353,7 +241,7 @@ class GinkaLoss(nn.Module):
def forward(self, pred, target, target_vision_feat, target_topo_feat):
# 地图结构损失
class_loss = outer_border_constraint_loss(pred) + inner_constraint_loss(pred)
entrance_loss = entrance_constraint_loss(pred) + entrance_spatial_constraint(pred)
entrance_loss = entrance_constraint_loss(pred)
count_loss = adaptive_count_loss(pred, target)
# 使用 Minamo Model 计算相似度

View File

@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .unet import GinkaUNet
from .output import GinkaOutput
from .input import GinkaInput, GinkaFeatureInput
def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@ -12,23 +13,27 @@ class GinkaModel(nn.Module):
"""Ginka Model 模型定义部分
"""
super().__init__()
self.unet = GinkaUNet(1, base_ch, out_ch, feat_dim)
self.output = GinkaOutput(out_ch, out_ch, (13, 13))
self.input = GinkaInput(feat_dim, 1, (32, 32))
self.feat_enc = GinkaFeatureInput(feat_dim, 2, base_ch)
self.unet = GinkaUNet(1, base_ch, base_ch, feat_dim)
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
def forward(self, x, feat):
def forward(self, x):
"""
Args:
feat: 参考地图的特征向量
x: 参考地图的特征向量
Returns:
logits: 输出logits [BS, num_classes, H, W]
"""
x = self.unet(x, feat)
cond = x
feat = self.feat_enc(x)
x = self.input(x)
x = self.unet(x, feat, cond)
x = self.output(x)
return x, F.softmax(x, dim=1)
# 检查显存占用
if __name__ == "__main__":
x = torch.randn((1, 1, 32, 32)).cuda()
feat = torch.randn((1, 1024)).cuda()
# 初始化模型
@ -37,12 +42,14 @@ if __name__ == "__main__":
print_memory("初始化后")
# 前向传播
output, output_softmax = model(x, feat)
output, output_softmax = model(feat)
print_memory("前向传播后")
print(f"输入形状: x={x.shape}, feat={feat.shape}")
print(f"输入形状: feat={feat.shape}")
print(f"输出形状: output={output.shape}, softmax={output_softmax.shape}")
print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -2,11 +2,11 @@ import torch
import torch.nn as nn
class GinkaOutput(nn.Module):
def __init__(self, out_ch=32, base_ch=64, out_size=(13, 13)):
def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)):
super().__init__()
self.conv_down = nn.Sequential(
nn.AdaptiveMaxPool2d(out_size),
nn.Conv2d(base_ch, out_ch, 1)
nn.Conv2d(in_ch, out_ch, 1)
)
def forward(self, x):

View File

@ -2,29 +2,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
class GinkaAdaIN(nn.Module):
    """Adaptive instance normalization (AdaIN).

    A linear layer predicts a per-channel scale (gamma) and shift (beta)
    from a condition vector; the input is instance-normalized and then
    affinely transformed with them.

    Args:
        num_features: number of channels to normalize.
        condition_dim: dimensionality of the condition vector.
    """

    def __init__(self, num_features, condition_dim):
        super().__init__()
        # Predicts gamma and beta stacked along the feature axis.
        self.fc = nn.Linear(condition_dim, num_features * 2)

    def forward(self, x, condition):
        """x: [B, C, H, W]; condition: [B, condition_dim]."""
        batch, channels = x.shape[0], x.shape[1]
        params = self.fc(condition)
        gamma, beta = (
            t.reshape(batch, channels, 1, 1) for t in params.chunk(2, dim=1)
        )
        normalized = F.instance_norm(x)
        return normalized * gamma + beta
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
@ -40,29 +17,37 @@ class ConvBlock(nn.Module):
def forward(self, x):
return self.conv(x)
class AdaINConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, feat_dim):
class ConditionFusionBlock(nn.Module):
    """Residually mix a condition feature map into the input.

    Computes ``x + alpha * cond_feat`` where ``alpha`` is a learnable
    scalar initialized to 0.5.
    """

    def __init__(self):
        super().__init__()
        # Learnable fusion coefficient, starts at 0.5.
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, cond_feat):
        """Residual fusion; ``cond_feat`` must broadcast against ``x``."""
        scaled = cond_feat * self.alpha
        return x + scaled
class FusionConvBlock(nn.Module):
    """Convolution followed by residual condition fusion.

    Applies ``ConvBlock(in_ch, out_ch)`` and then mixes the condition
    feature map in through a ``ConditionFusionBlock``.

    NOTE(review): this span contained leftover pre-change lines from the
    AdaIN version (``self.adain`` assignment and call) interleaved with
    the current code; only the post-change implementation is kept.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.fusion = ConditionFusionBlock()

    def forward(self, x, feat):
        # feat is expected to match the conv output's shape (or broadcast).
        x = self.conv(x)
        x = self.fusion(x, feat)
        return x
class GinkaEncoder(nn.Module):
    """Encoder (downsampling) stage: conv -> 2x max-pool -> condition fusion.

    NOTE(review): this span contained leftover pre-change lines from the
    AdaIN version (``self.adain``) interleaved with the current code;
    only the post-change implementation is kept.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.pool = nn.MaxPool2d(2)
        self.fusion = ConditionFusionBlock()

    def forward(self, x, feat):
        # feat is expected to match the pooled conv output's shape.
        x = self.conv(x)
        x = self.pool(x)
        x = self.fusion(x, feat)
        return x
class GinkaUpSample(nn.Module):
@ -79,17 +64,17 @@ class GinkaUpSample(nn.Module):
class GinkaDecoder(nn.Module):
    """Decoder (upsampling) stage of the UNet.

    Upsamples, concatenates the skip connection along the channel axis,
    convolves, then fuses the condition feature map in residually.

    NOTE(review): this span contained leftover pre-change lines from the
    AdaIN version (``self.adain`` and the old forward body) interleaved
    with the current code; only the post-change implementation is kept.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
        self.conv = ConvBlock(in_ch, out_ch)
        self.fusion = ConditionFusionBlock()

    def forward(self, x, skip, feat):
        # Upsample, then channel-concatenate with the encoder skip map.
        dec = self.upsample(x)
        x = torch.cat([dec, skip], dim=1)
        x = self.conv(x)
        x = self.fusion(x, feat)
        return x
class GinkaUNet(nn.Module):
@ -97,32 +82,32 @@ class GinkaUNet(nn.Module):
"""Ginka Model UNet 部分
"""
super().__init__()
self.in_conv = AdaINConvBlock(in_ch, base_ch, feat_dim)
self.down1 = GinkaEncoder(base_ch, base_ch*2, feat_dim)
self.down2 = GinkaEncoder(base_ch*2, base_ch*4, feat_dim)
self.down3 = GinkaEncoder(base_ch*4, base_ch*8, feat_dim)
self.in_conv = FusionConvBlock(in_ch, base_ch)
self.down1 = GinkaEncoder(base_ch, base_ch*2)
self.down2 = GinkaEncoder(base_ch*2, base_ch*4)
self.down3 = GinkaEncoder(base_ch*4, base_ch*8)
self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16, feat_dim)
self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16)
self.up1 = GinkaDecoder(base_ch*16, base_ch*8, feat_dim)
self.up2 = GinkaDecoder(base_ch*8, base_ch*4, feat_dim)
self.up3 = GinkaDecoder(base_ch*4, base_ch*2, feat_dim)
self.up4 = GinkaDecoder(base_ch*2, base_ch, feat_dim)
self.up1 = GinkaDecoder(base_ch*16, base_ch*8)
self.up2 = GinkaDecoder(base_ch*8, base_ch*4)
self.up3 = GinkaDecoder(base_ch*4, base_ch*2)
self.up4 = GinkaDecoder(base_ch*2, base_ch)
self.final = nn.Sequential(
nn.Conv2d(base_ch, out_ch, 1),
)
def forward(self, x, feat):
x1 = self.in_conv(x, feat)
x2 = self.down1(x1, feat)
x3 = self.down2(x2, feat)
x4 = self.down3(x3, feat)
x5 = self.bottleneck(x4, feat)
def forward(self, x, feat, cond):
x1 = self.in_conv(x, feat[0])
x2 = self.down1(x1, feat[1])
x3 = self.down2(x2, feat[2])
x4 = self.down3(x3, feat[3])
x5 = self.bottleneck(x4, feat[4])
x = self.up1(x5, x4, feat)
x = self.up2(x, x3, feat)
x = self.up3(x, x2, feat)
x = self.up4(x, x1, feat)
x = self.up1(x5, x4, feat[3])
x = self.up2(x, x3, feat[2])
x = self.up3(x, x2, feat[1])
x = self.up4(x, x1, feat[0])
return self.final(x)

View File

@ -75,8 +75,7 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
# 前向传播
optimizer.zero_grad()
noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
_, output_softmax = model(noise, feat_vec)
_, output_softmax = model(feat_vec)
# 计算损失
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
@ -108,8 +107,7 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
# 前向传播
noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
output, output_softmax = model(noise, feat_vec)
output, output_softmax = model(feat_vec)
print(torch.argmax(output, dim=1)[0])
# 计算损失

View File

@ -108,8 +108,7 @@ def validate():
target_topo_feat = batch["target_topo_feat"].to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
# 前向传播
noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
output, output_softmax = model(noise, feat_vec)
output, output_softmax = model(feat_vec)
map_matrix = torch.argmax(output, dim=1)
for matrix in map_matrix[:].cpu():