Mirror of https://github.com/unanmed/ginka-generator.git (synced 2026-05-14 04:41:12 +08:00)
feat: improve the generator network
This commit is contained in: parent 96f828e29b · commit 4721e9a141
ginka/model/input.py (new file, 44 lines added)
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+
+class GinkaInput(nn.Module):
+    def __init__(self, feat_dim=1024, out_ch=1, size=(32, 32)):
+        super().__init__()
+        self.fc = nn.Sequential(
+            nn.Linear(feat_dim, size[0] * size[1] * out_ch),
+            nn.Unflatten(1, (out_ch, *size))
+        )
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+class FeatureEncoder(nn.Module):
+    def __init__(self, feat_dim, size, mid_ch, out_ch):
+        super().__init__()
+        self.encode = nn.Sequential(
+            nn.Linear(feat_dim, mid_ch * size * size),
+            nn.Unflatten(1, (mid_ch, size, size)),
+            nn.Conv2d(mid_ch, out_ch, 1)
+        )
+
+    def forward(self, x):
+        x = self.encode(x)
+        return x
+
+class GinkaFeatureInput(nn.Module):
+    def __init__(self, feat_dim=1024, mid_ch=1, out_ch=64):
+        super().__init__()
+        self.encode1 = FeatureEncoder(feat_dim, 32, mid_ch, out_ch)
+        self.encode2 = FeatureEncoder(feat_dim, 16, mid_ch * 2, out_ch * 2)
+        self.encode3 = FeatureEncoder(feat_dim, 8, mid_ch * 4, out_ch * 4)
+        self.encode4 = FeatureEncoder(feat_dim, 4, mid_ch * 8, out_ch * 8)
+        self.encode5 = FeatureEncoder(feat_dim, 2, mid_ch * 16, out_ch * 16)
+
+    def forward(self, x):
+        x1 = self.encode1(x)
+        x2 = self.encode2(x)
+        x3 = self.encode3(x)
+        x4 = self.encode4(x)
+        x5 = self.encode5(x)
+        return x1, x2, x3, x4, x5
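For reference, a minimal shape check of the new input modules (a sketch, not part of the commit; assumes the repo root is on PYTHONPATH):

import torch
from ginka.model.input import GinkaInput, GinkaFeatureInput

feat = torch.randn(2, 1024)                      # [B, feat_dim]
seed = GinkaInput(1024, 1, (32, 32))(feat)       # [2, 1, 32, 32] UNet seed map
pyramid = GinkaFeatureInput(1024, 2, 64)(feat)   # mid_ch=2, out_ch=64 as in model.py
for i, f in enumerate(pyramid):
    print(i, tuple(f.shape))
# 0 (2, 64, 32, 32)   1 (2, 128, 16, 16)   2 (2, 256, 8, 8)
# 3 (2, 512, 4, 4)    4 (2, 1024, 2, 2)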
@@ -86,7 +86,7 @@ def _create_distance_kernel(size):
     dist = torch.sqrt((x - center)**2 + (y - center)**2)
     kernel = 1 / (dist + 1)
     kernel /= kernel.sum()  # normalize
-    return kernel.unsqueeze(0).unsqueeze(0), 1 / kernel.sum()  # [1,1,H,W]
+    return kernel.unsqueeze(0).unsqueeze(0)  # [1,1,H,W]

 def entrance_constraint_loss(
     pred: torch.Tensor,
@@ -123,15 +123,13 @@ def entrance_constraint_loss(
     center_weight = center_weight.clamp(0,1).to(pred.device)  # [H,W]

     # probability-density-aware spacing computation
-    kernel, cw = _create_distance_kernel(min_distance)  # helper that builds the weight kernel
+    kernel = _create_distance_kernel(min_distance)  # helper that builds the weight kernel
     kernel = kernel.to(pred.device)
     density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1)

     spacing_loss = density_map.mean()

     ###########################
     # region-weighted combined loss
     ###########################
     total_loss = (
         lambda_presence * presence_loss +
         lambda_spacing * (spacing_loss * center_weight).mean()
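A standalone sketch of the kernel builder after this change (single tensor return; the x/y/center setup is assumed from the unshown lines above the hunk):

import torch

def create_distance_kernel(size):
    center = (size - 1) / 2                     # assumed grid setup
    y, x = torch.meshgrid(torch.arange(size, dtype=torch.float32),
                          torch.arange(size, dtype=torch.float32), indexing="ij")
    dist = torch.sqrt((x - center) ** 2 + (y - center) ** 2)
    kernel = 1 / (dist + 1)
    kernel /= kernel.sum()                      # normalize to sum 1
    return kernel.unsqueeze(0).unsqueeze(0)     # [1, 1, size, size]

k = create_distance_kernel(3)
print(k.shape, k.sum())                         # torch.Size([1, 1, 3, 3]) sum ≈ 1.0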
@@ -225,116 +223,6 @@ def adaptive_count_loss(

     return total_loss

-def illegal_tile_loss(
-    pred_probs: torch.Tensor,
-    legal_classes: int = 13,
-    temperature: float = 0.1,
-    eps: float = 1e-8
-) -> torch.Tensor:
-    """
-    Illegal-tile penalty loss.
-
-    Args:
-        pred_probs: model output probability distribution [B, C, H, W]
-        legal_classes: number of legal tile classes (0-based; 0-12 legal by default)
-        temperature: probability-sharpening temperature (0.1-1.0)
-        eps: numerical-stability constant
-
-    Returns:
-        loss: scalar loss value
-    """
-    B, C, H, W = pred_probs.shape
-
-    # extract illegal-tile probabilities (class 13 and above)
-    illegal_probs = pred_probs[:, legal_classes:, :, :]  # [B, C_illegal, H, W]
-
-    # probability sharpening (amplifies the penalty in high-probability regions)
-    sharpened_probs = torch.exp(torch.log(illegal_probs + eps) / temperature)
-    sharpened_probs = sharpened_probs / (sharpened_probs.sum(dim=1, keepdim=True) + eps)
-
-    # spatially sensitive weights (focus on high-confidence illegal regions)
-    with torch.no_grad():
-        # per-pixel confidence of the illegal probabilities
-        confidence = illegal_probs.max(dim=1)[0]  # [B, H, W]
-        # attention weights (high-confidence regions get up to double weight)
-        spatial_weights = 1 + torch.sigmoid(10*(confidence - 0.5))
-
-    # per-pixel illegal-probability loss
-    per_pixel_loss = torch.log(1 + illegal_probs.sum(dim=1))  # [B, H, W]
-
-    # spatially weighted loss
-    weighted_loss = (per_pixel_loss * spatial_weights).mean()
-
-    # class-balance factor (suppresses frequently predicted illegal classes)
-    class_balance = 1 + torch.var(illegal_probs.mean(dim=(0,2,3)))  # [C_illegal]
-
-    return weighted_loss * class_balance.mean()
-
-def entrance_spatial_constraint(
-    pred_probs: torch.Tensor,
-    arrow_class: int = 11,
-    stair_class: int = 10,
-    border_width: int = 1,
-    lambda_arrow: float = 1.0,
-    lambda_stair: float = 1.0
-) -> torch.Tensor:
-    """
-    Entrance spatial-constraint loss.
-
-    Args:
-        pred_probs: model output probability distribution [B, C, H, W]
-        arrow_class: class index of arrow entrances
-        stair_class: class index of stair entrances
-        border_width: width of the border region (1 = outermost ring)
-        lambda_arrow: arrow-constraint weight
-        lambda_stair: stair-constraint weight
-
-    Returns:
-        loss: scalar loss value
-    """
-    B, C, H, W = pred_probs.shape
-
-    ##########################################
-    # 1. Region mask generation
-    ##########################################
-    # border-region mask [H, W]
-    edge_mask = torch.zeros((H, W), dtype=torch.bool, device=pred_probs.device)
-    # top and bottom edges
-    edge_mask[:border_width, :] = True
-    edge_mask[-border_width:, :] = True
-    # left and right edges (corners already marked)
-    edge_mask[:, :border_width] = True
-    edge_mask[:, -border_width:] = True
-
-    # interior-region mask [H, W]
-    center_mask = ~edge_mask
-
-    ##########################################
-    # 2. Border constraint (arrows only)
-    ##########################################
-
-    # suppress stair probability on the border [B, N_edge_pixels]
-    edge_stair_probs = pred_probs[:, stair_class][:, edge_mask]
-    edge_stair_penalty = F.relu(edge_stair_probs - 0.1).mean()  # tolerate up to 10%
-
-    ##########################################
-    # 3. Interior constraint (stairs only)
-    ##########################################
-
-    # suppress arrow probability in the interior [B, N_center_pixels]
-    center_arrow_probs = pred_probs[:, arrow_class][:, center_mask]
-    center_arrow_penalty = F.relu(center_arrow_probs - 0.1).mean()  # tolerate up to 10%
-
-    ##########################################
-    # 4. Combined loss
-    ##########################################
-    total_loss = (
-        lambda_arrow * edge_stair_penalty +
-        lambda_stair * center_arrow_penalty
-    )
-
-    return total_loss

 class GinkaLoss(nn.Module):
     def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
         """Ginka Model loss function part
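The region masks that the deleted entrance_spatial_constraint built are easy to verify in isolation (a sketch of the deleted logic, for the record):

import torch

H, W, border_width = 5, 5, 1
edge_mask = torch.zeros((H, W), dtype=torch.bool)
edge_mask[:border_width, :] = True    # top edge
edge_mask[-border_width:, :] = True   # bottom edge
edge_mask[:, :border_width] = True    # left edge
edge_mask[:, -border_width:] = True   # right edge
center_mask = ~edge_mask              # 3x3 interior for a 5x5 map
print(edge_mask.int())                # ring of ones around a block of zeros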
@@ -353,7 +241,7 @@ class GinkaLoss(nn.Module):
     def forward(self, pred, target, target_vision_feat, target_topo_feat):
         # map structure losses
         class_loss = outer_border_constraint_loss(pred) + inner_constraint_loss(pred)
-        entrance_loss = entrance_constraint_loss(pred) + entrance_spatial_constraint(pred)
+        entrance_loss = entrance_constraint_loss(pred)
         count_loss = adaptive_count_loss(pred, target)

         # use the Minamo Model to compute similarity
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from .unet import GinkaUNet
 from .output import GinkaOutput
+from .input import GinkaInput, GinkaFeatureInput

 def print_memory(tag=""):
     print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@@ -12,23 +13,27 @@ class GinkaModel(nn.Module):
         """Ginka Model definition part
         """
         super().__init__()
-        self.unet = GinkaUNet(1, base_ch, out_ch, feat_dim)
-        self.output = GinkaOutput(out_ch, out_ch, (13, 13))
+        self.input = GinkaInput(feat_dim, 1, (32, 32))
+        self.feat_enc = GinkaFeatureInput(feat_dim, 2, base_ch)
+        self.unet = GinkaUNet(1, base_ch, base_ch, feat_dim)
+        self.output = GinkaOutput(base_ch, out_ch, (13, 13))

-    def forward(self, x, feat):
+    def forward(self, x):
         """
         Args:
-            feat: feature vector of the reference map
+            x: feature vector of the reference map
         Returns:
             logits: output logits [BS, num_classes, H, W]
         """
-        x = self.unet(x, feat)
+        cond = x
+        feat = self.feat_enc(x)
+        x = self.input(x)
+        x = self.unet(x, feat, cond)
         x = self.output(x)
         return x, F.softmax(x, dim=1)

 # check VRAM usage
 if __name__ == "__main__":
     x = torch.randn((1, 1, 32, 32)).cuda()
     feat = torch.randn((1, 1024)).cuda()

     # initialize the model
@@ -37,12 +42,14 @@ if __name__ == "__main__":
     print_memory("after initialization")

     # forward pass
-    output, output_softmax = model(x, feat)
+    output, output_softmax = model(feat)

     print_memory("after forward pass")

-    print(f"input shape: x={x.shape}, feat={feat.shape}")
+    print(f"input shape: feat={feat.shape}")
     print(f"output shape: output={output.shape}, softmax={output_softmax.shape}")
+    print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
+    print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
     print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
     print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
     print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
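The new forward path wired up manually from the committed modules, to make the data flow explicit (a sketch; module paths follow the imports above):

import torch
from ginka.model.input import GinkaInput, GinkaFeatureInput

feat = torch.randn(1, 1024)                      # reference-map feature vector
cond = feat                                      # raw vector, the UNet `cond` argument
pyramid = GinkaFeatureInput(1024, 2, 64)(feat)   # per-scale fusion features (base_ch=64 assumed)
seed = GinkaInput(1024, 1, (32, 32))(feat)       # [1, 1, 32, 32] UNet input
# seed, pyramid and cond then pass through GinkaUNet(1, 64, 64, 1024) and
# GinkaOutput(64, out_ch, (13, 13)), yielding logits of shape [1, out_ch, 13, 13].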
@@ -2,11 +2,11 @@ import torch
 import torch.nn as nn

 class GinkaOutput(nn.Module):
-    def __init__(self, out_ch=32, base_ch=64, out_size=(13, 13)):
+    def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)):
         super().__init__()
         self.conv_down = nn.Sequential(
             nn.AdaptiveMaxPool2d(out_size),
-            nn.Conv2d(base_ch, out_ch, 1)
+            nn.Conv2d(in_ch, out_ch, 1)
         )

     def forward(self, x):
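What the renamed parameters do: pool any feature map down to out_size, then map in_ch to out_ch class channels with a 1x1 convolution. A standalone sketch:

import torch
import torch.nn as nn

conv_down = nn.Sequential(
    nn.AdaptiveMaxPool2d((13, 13)),   # any HxW -> 13x13
    nn.Conv2d(64, 32, 1),             # in_ch=64 -> out_ch=32
)
x = torch.randn(2, 64, 32, 32)
print(conv_down(x).shape)             # torch.Size([2, 32, 13, 13])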
@@ -2,29 +2,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-class GinkaAdaIN(nn.Module):
-    def __init__(self, num_features, condition_dim):
-        """
-        Adaptive instance normalization (AdaIN)
-        Args:
-            num_features: number of channels to normalize
-            condition_dim: feature dimension of the condition input
-        """
-        super(GinkaAdaIN, self).__init__()
-        self.fc = nn.Linear(condition_dim, num_features * 2)  # γ and β
-
-    def forward(self, x, condition):
-        """
-        x: [B, C, H, W] - input feature map
-        condition: [B, condition_dim] - condition vector to inject
-        """
-        gamma, beta = self.fc(condition).chunk(2, dim=1)  # split into γ and β
-        gamma = gamma.view(x.shape[0], x.shape[1], 1, 1)  # reshape
-        beta = beta.view(x.shape[0], x.shape[1], 1, 1)
-
-        x = F.instance_norm(x)  # normalize
-        return gamma * x + beta  # affine transform

 class ConvBlock(nn.Module):
     def __init__(self, in_ch, out_ch):
         super().__init__()
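For contrast with the fusion block that replaces it below, the deleted AdaIN computed a condition-dependent affine transform of instance-normalized features; a self-contained sketch of that math:

import torch
import torch.nn as nn
import torch.nn.functional as F

fc = nn.Linear(1024, 64 * 2)                  # condition -> (gamma, beta)
x = torch.randn(2, 64, 32, 32)
condition = torch.randn(2, 1024)
gamma, beta = fc(condition).chunk(2, dim=1)
gamma = gamma.view(2, 64, 1, 1)               # per-channel scale
beta = beta.view(2, 64, 1, 1)                 # per-channel shift
out = gamma * F.instance_norm(x) + beta       # AdaIN
print(out.shape)                              # torch.Size([2, 64, 32, 32])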
@@ -40,29 +17,37 @@ class ConvBlock(nn.Module):
     def forward(self, x):
         return self.conv(x)

-class AdaINConvBlock(nn.Module):
-    def __init__(self, in_ch, out_ch, feat_dim):
+class ConditionFusionBlock(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.tensor(0.5))  # learnable fusion coefficient
+
+    def forward(self, x, cond_feat):
+        return x + self.alpha * cond_feat  # residual fusion
+
+class FusionConvBlock(nn.Module):
+    def __init__(self, in_ch, out_ch):
         super().__init__()
         self.conv = ConvBlock(in_ch, out_ch)
-        self.adain = GinkaAdaIN(out_ch, feat_dim)
+        self.fusion = ConditionFusionBlock()

     def forward(self, x, feat):
         x = self.conv(x)
-        x = self.adain(x, feat)
+        x = self.fusion(x, feat)
         return x

 class GinkaEncoder(nn.Module):
     """Encoder (downsampling) part"""
-    def __init__(self, in_ch, out_ch, feat_dim):
+    def __init__(self, in_ch, out_ch):
         super().__init__()
         self.conv = ConvBlock(in_ch, out_ch)
         self.pool = nn.MaxPool2d(2)
-        self.adain = GinkaAdaIN(out_ch, feat_dim)
+        self.fusion = ConditionFusionBlock()

     def forward(self, x, feat):
         x = self.conv(x)
         x = self.pool(x)
-        x = self.adain(x, feat)
+        x = self.fusion(x, feat)
         return x

 class GinkaUpSample(nn.Module):
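The replacement is far simpler: one learnable scalar blends a condition feature map into the activations, which is why the condition must now arrive as a spatial pyramid with exactly matching shapes. A usage sketch:

import torch
import torch.nn as nn

class ConditionFusionBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.5))  # learnable blend weight

    def forward(self, x, cond_feat):
        return x + self.alpha * cond_feat             # shapes must match exactly

x = torch.randn(2, 128, 16, 16)       # e.g. down1 output (base_ch*2 at 16x16)
cond = torch.randn(2, 128, 16, 16)    # matching level from GinkaFeatureInput
print(ConditionFusionBlock()(x, cond).shape)   # torch.Size([2, 128, 16, 16])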
@@ -79,17 +64,17 @@ class GinkaUpSample(nn.Module):

 class GinkaDecoder(nn.Module):
     """Decoder (upsampling) part"""
-    def __init__(self, in_ch, out_ch, feat_dim):
+    def __init__(self, in_ch, out_ch):
         super().__init__()
         self.upsample = GinkaUpSample(in_ch, in_ch // 2)
         self.conv = ConvBlock(in_ch, out_ch)
-        self.adain = GinkaAdaIN(out_ch, feat_dim)
+        self.fusion = ConditionFusionBlock()

     def forward(self, x, skip, feat):
-        x = self.upsample(x)
-        x = torch.cat([x, skip], dim=1)
+        dec = self.upsample(x)
+        x = torch.cat([dec, skip], dim=1)
         x = self.conv(x)
-        x = self.adain(x, feat)
+        x = self.fusion(x, feat)
         return x

 class GinkaUNet(nn.Module):
@@ -97,32 +82,32 @@ class GinkaUNet(nn.Module):
         """Ginka Model UNet part
         """
         super().__init__()
-        self.in_conv = AdaINConvBlock(in_ch, base_ch, feat_dim)
-        self.down1 = GinkaEncoder(base_ch, base_ch*2, feat_dim)
-        self.down2 = GinkaEncoder(base_ch*2, base_ch*4, feat_dim)
-        self.down3 = GinkaEncoder(base_ch*4, base_ch*8, feat_dim)
+        self.in_conv = FusionConvBlock(in_ch, base_ch)
+        self.down1 = GinkaEncoder(base_ch, base_ch*2)
+        self.down2 = GinkaEncoder(base_ch*2, base_ch*4)
+        self.down3 = GinkaEncoder(base_ch*4, base_ch*8)

-        self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16, feat_dim)
+        self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16)

-        self.up1 = GinkaDecoder(base_ch*16, base_ch*8, feat_dim)
-        self.up2 = GinkaDecoder(base_ch*8, base_ch*4, feat_dim)
-        self.up3 = GinkaDecoder(base_ch*4, base_ch*2, feat_dim)
-        self.up4 = GinkaDecoder(base_ch*2, base_ch, feat_dim)
+        self.up1 = GinkaDecoder(base_ch*16, base_ch*8)
+        self.up2 = GinkaDecoder(base_ch*8, base_ch*4)
+        self.up3 = GinkaDecoder(base_ch*4, base_ch*2)
+        self.up4 = GinkaDecoder(base_ch*2, base_ch)

         self.final = nn.Sequential(
             nn.Conv2d(base_ch, out_ch, 1),
         )

-    def forward(self, x, feat):
-        x1 = self.in_conv(x, feat)
-        x2 = self.down1(x1, feat)
-        x3 = self.down2(x2, feat)
-        x4 = self.down3(x3, feat)
-        x5 = self.bottleneck(x4, feat)
+    def forward(self, x, feat, cond):
+        x1 = self.in_conv(x, feat[0])
+        x2 = self.down1(x1, feat[1])
+        x3 = self.down2(x2, feat[2])
+        x4 = self.down3(x3, feat[3])
+        x5 = self.bottleneck(x4, feat[4])

-        x = self.up1(x5, x4, feat)
-        x = self.up2(x, x3, feat)
-        x = self.up3(x, x2, feat)
-        x = self.up4(x, x1, feat)
+        x = self.up1(x5, x4, feat[3])
+        x = self.up2(x, x3, feat[2])
+        x = self.up3(x, x2, feat[1])
+        x = self.up4(x, x1, feat[0])

         return self.final(x)
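Stage-to-pyramid correspondence implied by the new forward, with GinkaFeatureInput(feat_dim, 2, base_ch) from model.py (base_ch=64 assumed for the shapes):

stage        activation (base_ch=64)    fused with
in_conv      [B,   64, 32, 32]          feat[0]  [B,   64, 32, 32]
down1        [B,  128, 16, 16]          feat[1]  [B,  128, 16, 16]
down2        [B,  256,  8,  8]          feat[2]  [B,  256,  8,  8]
down3        [B,  512,  4,  4]          feat[3]  [B,  512,  4,  4]
bottleneck   [B, 1024,  2,  2]          feat[4]  [B, 1024,  2,  2]
up1..up4     mirror back up through feat[3]..feat[0]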
@@ -75,8 +75,7 @@ def train():
     feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
     # forward pass
     optimizer.zero_grad()
-    noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
-    _, output_softmax = model(noise, feat_vec)
+    _, output_softmax = model(feat_vec)

     # compute the loss
     scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
@@ -108,8 +107,7 @@ def train():
     feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)

     # forward pass
-    noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
-    output, output_softmax = model(noise, feat_vec)
+    output, output_softmax = model(feat_vec)
     print(torch.argmax(output, dim=1)[0])

     # compute the loss
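The vector fed to the model is the concatenation of the Minamo vision and topology features; a shape sketch (the 512/512 split is an assumption; only the combined width of 1024 is implied by feat_dim):

import torch

target_vision_feat = torch.randn(4, 1, 512)   # assumed width
target_topo_feat = torch.randn(4, 1, 512)     # assumed width
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).squeeze(1)
print(feat_vec.shape)                         # torch.Size([4, 1024]) -> model(feat_vec)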
@@ -108,8 +108,7 @@ def validate():
     target_topo_feat = batch["target_topo_feat"].to(device)
     feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
     # forward pass
-    noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
-    output, output_softmax = model(noise, feat_vec)
+    output, output_softmax = model(feat_vec)
     map_matrix = torch.argmax(output, dim=1)

     for matrix in map_matrix[:].cpu():
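Validation then decodes the logits into a tile-ID matrix with an argmax over the class dimension; a minimal sketch (13x13 per GinkaOutput; 32 classes is an assumption):

import torch

output = torch.randn(4, 32, 13, 13)         # [B, num_classes, H, W] logits
map_matrix = torch.argmax(output, dim=1)    # [B, 13, 13] integer tile IDs
print(map_matrix.shape, map_matrix.dtype)   # torch.Size([4, 13, 13]) torch.int64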