diff --git a/ginka/model/input.py b/ginka/model/input.py new file mode 100644 index 0000000..7796367 --- /dev/null +++ b/ginka/model/input.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn + +class GinkaInput(nn.Module): + def __init__(self, feat_dim=1024, out_ch=1, size=(32, 32)): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(feat_dim, size[0] * size[1] * out_ch), + nn.Unflatten(1, (out_ch, *size)) + ) + + def forward(self, x): + x = self.fc(x) + return x + +class FeatureEncoder(nn.Module): + def __init__(self, feat_dim, size, mid_ch, out_ch): + super().__init__() + self.encode = nn.Sequential( + nn.Linear(feat_dim, mid_ch * size * size), + nn.Unflatten(1, (mid_ch, size, size)), + nn.Conv2d(mid_ch, out_ch, 1) + ) + + def forward(self, x): + x = self.encode(x) + return x + +class GinkaFeatureInput(nn.Module): + def __init__(self, feat_dim=1024, mid_ch=1, out_ch=64): + super().__init__() + self.encode1 = FeatureEncoder(feat_dim, 32, mid_ch, out_ch) + self.encode2 = FeatureEncoder(feat_dim, 16, mid_ch * 2, out_ch * 2) + self.encode3 = FeatureEncoder(feat_dim, 8, mid_ch * 4, out_ch * 4) + self.encode4 = FeatureEncoder(feat_dim, 4, mid_ch * 8, out_ch * 8) + self.encode5 = FeatureEncoder(feat_dim, 2, mid_ch * 16, out_ch * 16) + + def forward(self, x): + x1 = self.encode1(x) + x2 = self.encode2(x) + x3 = self.encode3(x) + x4 = self.encode4(x) + x5 = self.encode5(x) + return x1, x2, x3, x4, x5 diff --git a/ginka/model/loss.py b/ginka/model/loss.py index a083750..ab7294e 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -86,7 +86,7 @@ def _create_distance_kernel(size): dist = torch.sqrt((x - center)**2 + (y - center)**2) kernel = 1 / (dist + 1) kernel /= kernel.sum() # 归一化 - return kernel.unsqueeze(0).unsqueeze(0), 1 / kernel.sum() # [1,1,H,W] + return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W] def entrance_constraint_loss( pred: torch.Tensor, @@ -123,15 +123,13 @@ def entrance_constraint_loss( center_weight = center_weight.clamp(0,1).to(pred.device) # [H,W] # 概率密度感知的间距计算 - kernel, cw = _create_distance_kernel(min_distance) # 自定义函数生成权重核 + kernel = _create_distance_kernel(min_distance) # 自定义函数生成权重核 kernel = kernel.to(pred.device) density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1) spacing_loss = density_map.mean() - ########################### # 区域加权综合损失 - ########################### total_loss = ( lambda_presence * presence_loss + lambda_spacing * (spacing_loss * center_weight).mean() @@ -225,116 +223,6 @@ def adaptive_count_loss( return total_loss -def illegal_tile_loss( - pred_probs: torch.Tensor, - legal_classes: int = 13, - temperature: float = 0.1, - eps: float = 1e-8 -) -> torch.Tensor: - """ - 非法图块惩罚损失函数 - - 参数: - pred_probs: 模型输出的概率分布 [B, C, H, W] - legal_classes: 合法图块数量(0-based,默认0-12为合法) - temperature: 概率锐化温度系数(0.1-1.0) - eps: 数值稳定性保护 - - 返回: - loss: 标量损失值 - """ - B, C, H, W = pred_probs.shape - - # 提取非法图块概率(类别13及之后) - illegal_probs = pred_probs[:, legal_classes:, :, :] # [B, C_illegal, H, W] - - # 概率锐化(增强高概率区域的惩罚) - sharpened_probs = torch.exp(torch.log(illegal_probs + eps) / temperature) - sharpened_probs = sharpened_probs / (sharpened_probs.sum(dim=1, keepdim=True) + eps) - - # 空间敏感权重(关注高置信度非法区域) - with torch.no_grad(): - # 计算每个像素的非法概率置信度 - confidence = illegal_probs.max(dim=1)[0] # [B, H, W] - # 生成注意力权重(高置信度区域权重加倍) - spatial_weights = 1 + torch.sigmoid(10*(confidence - 0.5)) - - # 逐像素计算非法概率损失 - per_pixel_loss = torch.log(1 + illegal_probs.sum(dim=1)) # [B, H, W] - - # 加权空间损失 - weighted_loss = (per_pixel_loss * spatial_weights).mean() - - # 类别平衡因子(抑制高频非法类别) - class_balance = 1 + torch.var(illegal_probs.mean(dim=(0,2,3))) # [C_illegal] - - return weighted_loss * class_balance.mean() - -def entrance_spatial_constraint( - pred_probs: torch.Tensor, - arrow_class: int = 11, - stair_class: int = 10, - border_width: int = 1, - lambda_arrow: float = 1.0, - lambda_stair: float = 1.0 -) -> torch.Tensor: - """ - 入口空间约束损失函数 - - 参数: - pred_probs: 模型输出的概率分布 [B, C, H, W] - arrow_class: 箭头入口类别索引 - stair_class: 楼梯入口类别索引 - border_width: 边缘区域宽度(默认1表示最外圈) - lambda_arrow: 箭头约束权重 - lambda_stair: 楼梯约束权重 - - 返回: - loss: 标量损失值 - """ - B, C, H, W = pred_probs.shape - - ########################################## - # 1. 区域掩码生成 - ########################################## - # 生成边缘区域掩码 [H, W] - edge_mask = torch.zeros((H, W), dtype=torch.bool, device=pred_probs.device) - # 上下边缘 - edge_mask[:border_width, :] = True - edge_mask[-border_width:, :] = True - # 左右边缘(排除已标记的角落) - edge_mask[:, :border_width] = True - edge_mask[:, -border_width:] = True - - # 生成中间区域掩码 [H, W] - center_mask = ~edge_mask - - ########################################## - # 2. 边缘区域约束(只能出现箭头) - ########################################## - - # 抑制边缘出现楼梯的概率 [B, N_edge_pixels] - edge_stair_probs = pred_probs[:, stair_class][:, edge_mask] - edge_stair_penalty = F.relu(edge_stair_probs - 0.1).mean() # 允许10%以下 - - ########################################## - # 3. 中间区域约束(只能出现楼梯) - ########################################## - - # 抑制中间出现箭头的概率 [B, N_center_pixels] - center_arrow_probs = pred_probs[:, arrow_class][:, center_mask] - center_arrow_penalty = F.relu(center_arrow_probs - 0.1).mean() # 允许10%以下 - - ########################################## - # 4. 综合损失 - ########################################## - total_loss = ( - lambda_arrow * edge_stair_penalty + - lambda_stair * center_arrow_penalty - ) - - return total_loss - class GinkaLoss(nn.Module): def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]): """Ginka Model 损失函数部分 @@ -353,7 +241,7 @@ class GinkaLoss(nn.Module): def forward(self, pred, target, target_vision_feat, target_topo_feat): # 地图结构损失 class_loss = outer_border_constraint_loss(pred) + inner_constraint_loss(pred) - entrance_loss = entrance_constraint_loss(pred) + entrance_spatial_constraint(pred) + entrance_loss = entrance_constraint_loss(pred) count_loss = adaptive_count_loss(pred, target) # 使用 Minamo Model 计算相似度 diff --git a/ginka/model/model.py b/ginka/model/model.py index b534ff3..0c75ba4 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from .unet import GinkaUNet from .output import GinkaOutput +from .input import GinkaInput, GinkaFeatureInput def print_memory(tag=""): print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") @@ -12,23 +13,27 @@ class GinkaModel(nn.Module): """Ginka Model 模型定义部分 """ super().__init__() - self.unet = GinkaUNet(1, base_ch, out_ch, feat_dim) - self.output = GinkaOutput(out_ch, out_ch, (13, 13)) + self.input = GinkaInput(feat_dim, 1, (32, 32)) + self.feat_enc = GinkaFeatureInput(feat_dim, 2, base_ch) + self.unet = GinkaUNet(1, base_ch, base_ch, feat_dim) + self.output = GinkaOutput(base_ch, out_ch, (13, 13)) - def forward(self, x, feat): + def forward(self, x): """ Args: - feat: 参考地图的特征向量 + x: 参考地图的特征向量 Returns: logits: 输出logits [BS, num_classes, H, W] """ - x = self.unet(x, feat) + cond = x + feat = self.feat_enc(x) + x = self.input(x) + x = self.unet(x, feat, cond) x = self.output(x) return x, F.softmax(x, dim=1) # 检查显存占用 if __name__ == "__main__": - x = torch.randn((1, 1, 32, 32)).cuda() feat = torch.randn((1, 1024)).cuda() # 初始化模型 @@ -37,12 +42,14 @@ if __name__ == "__main__": print_memory("初始化后") # 前向传播 - output, output_softmax = model(x, feat) + output, output_softmax = model(feat) print_memory("前向传播后") - print(f"输入形状: x={x.shape}, feat={feat.shape}") + print(f"输入形状: feat={feat.shape}") print(f"输出形状: output={output.shape}, softmax={output_softmax.shape}") + print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") + print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}") print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}") print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}") print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/model/output.py b/ginka/model/output.py index 7e34208..11f77ed 100644 --- a/ginka/model/output.py +++ b/ginka/model/output.py @@ -2,11 +2,11 @@ import torch import torch.nn as nn class GinkaOutput(nn.Module): - def __init__(self, out_ch=32, base_ch=64, out_size=(13, 13)): + def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)): super().__init__() self.conv_down = nn.Sequential( nn.AdaptiveMaxPool2d(out_size), - nn.Conv2d(base_ch, out_ch, 1) + nn.Conv2d(in_ch, out_ch, 1) ) def forward(self, x): diff --git a/ginka/model/unet.py b/ginka/model/unet.py index 3b7484d..e6de53b 100644 --- a/ginka/model/unet.py +++ b/ginka/model/unet.py @@ -2,29 +2,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -class GinkaAdaIN(nn.Module): - def __init__(self, num_features, condition_dim): - """ - 自适应实例归一化 (AdaIN) - 参数: - num_features: 归一化的通道数 - condition_dim: 条件输入的特征维度 - """ - super(GinkaAdaIN, self).__init__() - self.fc = nn.Linear(condition_dim, num_features * 2) # γ 和 β - - def forward(self, x, condition): - """ - x: [B, C, H, W] - 输入特征图 - condition: [B, condition_dim] - 需要注入的条件向量 - """ - gamma, beta = self.fc(condition).chunk(2, dim=1) # 分割为 γ 和 β - gamma = gamma.view(x.shape[0], x.shape[1], 1, 1) # 调整形状 - beta = beta.view(x.shape[0], x.shape[1], 1, 1) - - x = F.instance_norm(x) # 标准化 - return gamma * x + beta # 进行变换 - class ConvBlock(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() @@ -40,29 +17,37 @@ class ConvBlock(nn.Module): def forward(self, x): return self.conv(x) -class AdaINConvBlock(nn.Module): - def __init__(self, in_ch, out_ch, feat_dim): +class ConditionFusionBlock(nn.Module): + def __init__(self): + super().__init__() + self.alpha = nn.Parameter(torch.tensor(0.5)) # 可学习融合系数 + + def forward(self, x, cond_feat): + return x + self.alpha * cond_feat # 残差融合 + +class FusionConvBlock(nn.Module): + def __init__(self, in_ch, out_ch): super().__init__() self.conv = ConvBlock(in_ch, out_ch) - self.adain = GinkaAdaIN(out_ch, feat_dim) + self.fusion = ConditionFusionBlock() def forward(self, x, feat): x = self.conv(x) - x = self.adain(x, feat) + x = self.fusion(x, feat) return x class GinkaEncoder(nn.Module): """编码器(下采样)部分""" - def __init__(self, in_ch, out_ch, feat_dim): + def __init__(self, in_ch, out_ch): super().__init__() self.conv = ConvBlock(in_ch, out_ch) self.pool = nn.MaxPool2d(2) - self.adain = GinkaAdaIN(out_ch, feat_dim) + self.fusion = ConditionFusionBlock() def forward(self, x, feat): x = self.conv(x) x = self.pool(x) - x = self.adain(x, feat) + x = self.fusion(x, feat) return x class GinkaUpSample(nn.Module): @@ -79,17 +64,17 @@ class GinkaUpSample(nn.Module): class GinkaDecoder(nn.Module): """解码器(上采样)部分""" - def __init__(self, in_ch, out_ch, feat_dim): + def __init__(self, in_ch, out_ch): super().__init__() self.upsample = GinkaUpSample(in_ch, in_ch // 2) self.conv = ConvBlock(in_ch, out_ch) - self.adain = GinkaAdaIN(out_ch, feat_dim) + self.fusion = ConditionFusionBlock() def forward(self, x, skip, feat): - x = self.upsample(x) - x = torch.cat([x, skip], dim=1) + dec = self.upsample(x) + x = torch.cat([dec, skip], dim=1) x = self.conv(x) - x = self.adain(x, feat) + x = self.fusion(x, feat) return x class GinkaUNet(nn.Module): @@ -97,32 +82,32 @@ class GinkaUNet(nn.Module): """Ginka Model UNet 部分 """ super().__init__() - self.in_conv = AdaINConvBlock(in_ch, base_ch, feat_dim) - self.down1 = GinkaEncoder(base_ch, base_ch*2, feat_dim) - self.down2 = GinkaEncoder(base_ch*2, base_ch*4, feat_dim) - self.down3 = GinkaEncoder(base_ch*4, base_ch*8, feat_dim) + self.in_conv = FusionConvBlock(in_ch, base_ch) + self.down1 = GinkaEncoder(base_ch, base_ch*2) + self.down2 = GinkaEncoder(base_ch*2, base_ch*4) + self.down3 = GinkaEncoder(base_ch*4, base_ch*8) - self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16, feat_dim) + self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16) - self.up1 = GinkaDecoder(base_ch*16, base_ch*8, feat_dim) - self.up2 = GinkaDecoder(base_ch*8, base_ch*4, feat_dim) - self.up3 = GinkaDecoder(base_ch*4, base_ch*2, feat_dim) - self.up4 = GinkaDecoder(base_ch*2, base_ch, feat_dim) + self.up1 = GinkaDecoder(base_ch*16, base_ch*8) + self.up2 = GinkaDecoder(base_ch*8, base_ch*4) + self.up3 = GinkaDecoder(base_ch*4, base_ch*2) + self.up4 = GinkaDecoder(base_ch*2, base_ch) self.final = nn.Sequential( nn.Conv2d(base_ch, out_ch, 1), ) - def forward(self, x, feat): - x1 = self.in_conv(x, feat) - x2 = self.down1(x1, feat) - x3 = self.down2(x2, feat) - x4 = self.down3(x3, feat) - x5 = self.bottleneck(x4, feat) + def forward(self, x, feat, cond): + x1 = self.in_conv(x, feat[0]) + x2 = self.down1(x1, feat[1]) + x3 = self.down2(x2, feat[2]) + x4 = self.down3(x3, feat[3]) + x5 = self.bottleneck(x4, feat[4]) - x = self.up1(x5, x4, feat) - x = self.up2(x, x3, feat) - x = self.up3(x, x2, feat) - x = self.up4(x, x1, feat) + x = self.up1(x5, x4, feat[3]) + x = self.up2(x, x3, feat[2]) + x = self.up3(x, x2, feat[1]) + x = self.up4(x, x1, feat[0]) return self.final(x) diff --git a/ginka/train.py b/ginka/train.py index 34ef1c5..86ff5cb 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -75,8 +75,7 @@ def train(): feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) # 前向传播 optimizer.zero_grad() - noise = torch.randn((target.shape[0], 1, 32, 32)).to(device) - _, output_softmax = model(noise, feat_vec) + _, output_softmax = model(feat_vec) # 计算损失 scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat) @@ -108,8 +107,7 @@ def train(): feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) # 前向传播 - noise = torch.randn((target.shape[0], 1, 32, 32)).to(device) - output, output_softmax = model(noise, feat_vec) + output, output_softmax = model(feat_vec) print(torch.argmax(output, dim=1)[0]) # 计算损失 diff --git a/ginka/validate.py b/ginka/validate.py index 750707f..89ef941 100644 --- a/ginka/validate.py +++ b/ginka/validate.py @@ -108,8 +108,7 @@ def validate(): target_topo_feat = batch["target_topo_feat"].to(device) feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) # 前向传播 - noise = torch.randn((target.shape[0], 1, 32, 32)).to(device) - output, output_softmax = model(noise, feat_vec) + output, output_softmax = model(feat_vec) map_matrix = torch.argmax(output, dim=1) for matrix in map_matrix[:].cpu():