perf: 加深 UNet 结构

This commit is contained in:
unanmed 2025-03-19 16:25:20 +08:00
parent 09c63fedce
commit 452df38866
6 changed files with 41 additions and 63 deletions

View File

@ -322,7 +322,7 @@ class GinkaLoss(nn.Module):
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
minamo_loss = torch.exp(-10 * (minamo_sim - 0.8)).mean()
minamo_loss = torch.exp(-1 * (minamo_sim - 0.8)).mean()
# print(
# minamo_loss.item(),

View File

@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from shared.attention import CBAM
class GinkaEncoder(nn.Module):
"""编码器(下采样)部分"""
@ -8,10 +9,11 @@ class GinkaEncoder(nn.Module):
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.GELU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
# CBAM(out_channels),
nn.GELU()
)
self.pool = nn.MaxPool2d(2)
@ -29,7 +31,8 @@ class GinkaDecoder(nn.Module):
self.conv = nn.Sequential(
nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
# CBAM(out_channels),
nn.GELU()
)
def forward(self, x, skip):
@ -44,10 +47,10 @@ class GinkaBottleneck(nn.Module):
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.GELU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.GELU(),
)
def forward(self, x):
@ -60,11 +63,15 @@ class GinkaUNet(nn.Module):
super().__init__()
self.down1 = GinkaEncoder(in_ch, in_ch*2)
self.down2 = GinkaEncoder(in_ch*2, in_ch*4)
self.down3 = GinkaEncoder(in_ch*4, in_ch*8)
self.down4 = GinkaEncoder(in_ch*8, in_ch*16)
self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4)
self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16)
self.up1 = GinkaDecoder(in_ch*4, in_ch*2)
self.up2 = GinkaDecoder(in_ch*2, in_ch)
self.up1 = GinkaDecoder(in_ch*16, in_ch*8)
self.up2 = GinkaDecoder(in_ch*8, in_ch*4)
self.up3 = GinkaDecoder(in_ch*4, in_ch*2)
self.up4 = GinkaDecoder(in_ch*2, in_ch)
self.final = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 1),
@ -74,10 +81,14 @@ class GinkaUNet(nn.Module):
def forward(self, x):
x_down1, skip1 = self.down1(x)
x_down2, skip2 = self.down2(x_down1)
x_down3, skip3 = self.down3(x_down2)
x_down4, skip4 = self.down4(x_down3)
x = self.bottleneck(x_down2)
x = self.bottleneck(x_down4)
x = self.up1(x, skip2) # 用 down2 的 skip
x = self.up2(x, skip1) # 用 down1 的 skip
x = self.up1(x, skip4) # 用 down4 的 skip
x = self.up2(x, skip3) # 用 down3 的 skip
x = self.up3(x, skip2) # 用 down2 的 skip
x = self.up4(x, skip1) # 用 down1 的 skip
return self.final(x)

View File

View File

@ -1,46 +0,0 @@
import torch
from ..model.model import DynamicPadConv, ConditionInjector, HybridUpsample
def test_dynamic_conv():
    """DynamicPadConv with stride=2 halves spatial dims, padding odd sizes up."""
    conv = DynamicPadConv(3, 64, stride=2)
    # Odd spatial size: 15x17 must be padded before downsampling to 8x9.
    out = conv(torch.randn(1, 3, 15, 17))
    assert out.shape == (1, 64, 8, 9), f"Got {out.shape}"
    # Even spatial size: 16x16 halves cleanly to 8x8.
    out = conv(torch.randn(1, 3, 16, 16))
    assert out.shape == (1, 64, 8, 8)
def test_condition_injector():
    """ConditionInjector keeps the feature shape while actually altering values."""
    injector = ConditionInjector(128, 256)
    feat = torch.randn(2, 256, 16, 16)
    cond_vec = torch.randn(2, 128)
    injected = injector(feat, cond_vec)
    assert injected.shape == feat.shape
    # The condition must have an effect — output may not equal the input.
    assert not torch.allclose(injected, feat)
def test_hybrid_upsample():
    """HybridUpsample doubles spatial size both with and without a skip input."""
    # Case 1: a skip connection is merged into the upsampled features.
    up_with_skip = HybridUpsample(256, 128, skip_ch=64)
    x = torch.randn(2, 256, 8, 8)
    skip_feat = torch.randn(2, 64, 16, 16)
    assert up_with_skip(x, skip_feat).shape == (2, 128, 16, 16)
    # Case 2: no skip connection — plain upsampling path.
    up_plain = HybridUpsample(256, 128)
    assert up_plain(x).shape == (2, 128, 16, 16)
def test_all():
    """Run every component test in order, printing a progress line after each."""
    checks = (
        (test_dynamic_conv, "✅ 动态卷积测试完毕"),
        (test_condition_injector, "✅ 条件注入测试完毕"),
        (test_hybrid_upsample, "✅ 混合上采样测试完毕"),
    )
    for check, done_msg in checks:
        check()
        print(done_msg)

View File

@ -1,4 +0,0 @@
# Runner for the model component test suite.
# NOTE(review): the relative import means this file must be executed as part of
# the package (e.g. `python -m <package>`), not as a standalone script.
from .test.model import test_all
if __name__ == "__main__":
    test_all()

View File

@ -14,7 +14,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True)
epochs = 70
epochs = 150
def update_tau(epoch):
start_tau = 1.0
@ -27,9 +27,13 @@ def train():
model = GinkaModel()
model.to(device)
minamo = MinamoModel(32)
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
minamo.to(device)
minamo.eval()
for param in minamo.parameters():
param.requires_grad = False
converter = DynamicGraphConverter().to(device)
# 准备数据集
@ -79,6 +83,18 @@ def train():
avg_loss = total_loss / len(dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
# total_norm = 0
# for p in model.parameters():
# if p.grad is not None:
# param_norm = p.grad.detach().data.norm(2)
# total_norm += param_norm.item() ** 2
# total_norm = total_norm ** 0.5
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
# for name, param in model.named_parameters():
# if param.grad is not None:
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
# 学习率调整
scheduler.step()
@ -95,6 +111,7 @@ def train():
# 前向传播
output, output_softmax = model(feat_vec)
print(output_softmax[0])
# 计算损失
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)