ginka-generator/ginka/model/sample.py
2025-03-15 22:26:31 +08:00

80 lines
2.8 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.ops as ops
class HybridUpsample(nn.Module):
"""自适应尺寸的混合上采样"""
def __init__(self, in_ch, out_ch, skip_ch=None):
super().__init__()
# 子像素卷积上采样
self.subpixel = nn.Sequential(
nn.Conv2d(in_ch, out_ch * 4, 3, padding=1),
nn.PixelShuffle(2) # 2倍上采样
)
# 跳跃连接处理
self.skip_conv = nn.Conv2d(skip_ch, out_ch, 1) if skip_ch else None
self.adaptive_pool = nn.AdaptiveAvgPool2d(None)
def forward(self, x, skip=None):
x = self.subpixel(x) # [B, out_ch, 2H, 2W]
if skip is not None and self.skip_conv:
# 自动对齐尺寸
if x.shape[-2:] != skip.shape[-2:]:
skip = F.interpolate(skip, size=x.shape[-2:], mode='nearest')
# 融合特征
x = x + self.skip_conv(skip)
return x
class DiscreteAwareUpsample(nn.Module):
"""离散感知的智能上采样模块"""
def __init__(self, in_ch, out_ch, base_size=16):
super().__init__()
self.base_size = base_size
self.scale_factors = [2, 4, 8] # 支持放大倍数
# 可变形卷积增强几何感知
self.deform_conv = ops.DeformConv2d(in_ch, in_ch, kernel_size=3, padding=1)
# 多尺度特征融合
self.multi_scale = nn.ModuleList([
nn.Sequential(
nn.Conv2d(in_ch, in_ch//4, 1),
nn.Upsample(scale_factor=s, mode='nearest')
) for s in self.scale_factors
])
# 门控上采样机制
self.gate_conv = nn.Conv2d(in_ch*2, len(self.scale_factors)+1, 3, padding=1)
# 离散化输出层
self.final_conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch*4, 3, padding=1),
nn.PixelShuffle(2), # 亚像素卷积
nn.Conv2d(out_ch, out_ch, 3, padding=1)
)
def forward(self, x, target_size):
# 几何特征提取
deform_feat = self.deform_conv(x)
# 生成多尺度特征
scale_features = [f(deform_feat) for f in self.multi_scale]
# 动态门控选择
gate_map = F.softmax(self.gate_conv(torch.cat([x, deform_feat], dim=1)), dim=1)
# 加权融合多尺度特征
combined = sum(g * F.interpolate(f, size=target_size, mode='nearest')
for g, f in zip(gate_map.unbind(1), scale_features+[x]))
# 离散化上采样
out = self.final_conv(combined)
# 结构化约束(保持通道独立性)
return out.argmax(dim=1).unsqueeze(1).float() # 伪梯度保留