Mirror of https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00

perf: deepen the UNet structure
This commit is contained in:
parent
09c63fedce
commit
452df38866
@@ -322,7 +322,7 @@ class GinkaLoss(nn.Module):
         vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
         topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
         minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
-        minamo_loss = torch.exp(-10 * (minamo_sim - 0.8)).mean()
+        minamo_loss = torch.exp(-1 * (minamo_sim - 0.8)).mean()
 
         # print(
         #     minamo_loss.item(),
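A note on the change above: shrinking the exponent from -10 to -1 keeps the shape of the loss (penalize similarity below the 0.8 target) but flattens it, so gradients no longer blow up when minamo_sim is low. A minimal standalone sketch with hypothetical similarity values:

    import torch

    # Hypothetical similarity values around the 0.8 target used in the hunk above.
    sim = torch.tensor([0.6, 0.8, 0.95], requires_grad=True)

    for k in (10.0, 1.0):
        loss = torch.exp(-k * (sim - 0.8)).mean()
        grad, = torch.autograd.grad(loss, sim)
        print(f"k={k}: loss={loss.item():.3f}, grad={grad}")
    # k=10 gives loss ~2.9 with gradients down to ~-25; k=1 stays near 1.0 with gentle gradients.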
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+from shared.attention import CBAM
 
 class GinkaEncoder(nn.Module):
     """Encoder (downsampling) block"""
@@ -8,10 +9,11 @@ class GinkaEncoder(nn.Module):
         self.conv = nn.Sequential(
             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU(),
+            nn.GELU(),
             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU()
+            # CBAM(out_channels),
+            nn.GELU()
         )
         self.pool = nn.MaxPool2d(2)
 
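For reference, the practical difference of the ReLU → GELU swap in one snippet (standalone, not from the repo): GELU is smooth and keeps a small negative response near zero instead of hard-gating it, which tends to help deeper stacks like this one train.

    import torch
    import torch.nn.functional as F

    x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
    print(torch.relu(x))   # tensor([0., 0., 0., 0.5, 2.0]) -- hard zero gate
    print(F.gelu(x))       # smooth curve, slightly negative for small negative inputs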
@@ -29,7 +31,8 @@ class GinkaDecoder(nn.Module):
         self.conv = nn.Sequential(
             nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU()
+            # CBAM(out_channels),
+            nn.GELU()
         )
 
     def forward(self, x, skip):
@@ -44,10 +47,10 @@ class GinkaBottleneck(nn.Module):
         self.conv = nn.Sequential(
             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU(),
+            nn.GELU(),
             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU(),
+            nn.GELU(),
         )
 
     def forward(self, x):
@@ -60,11 +63,15 @@ class GinkaUNet(nn.Module):
         super().__init__()
         self.down1 = GinkaEncoder(in_ch, in_ch*2)
         self.down2 = GinkaEncoder(in_ch*2, in_ch*4)
+        self.down3 = GinkaEncoder(in_ch*4, in_ch*8)
+        self.down4 = GinkaEncoder(in_ch*8, in_ch*16)
 
-        self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4)
+        self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16)
 
-        self.up1 = GinkaDecoder(in_ch*4, in_ch*2)
-        self.up2 = GinkaDecoder(in_ch*2, in_ch)
+        self.up1 = GinkaDecoder(in_ch*16, in_ch*8)
+        self.up2 = GinkaDecoder(in_ch*8, in_ch*4)
+        self.up3 = GinkaDecoder(in_ch*4, in_ch*2)
+        self.up4 = GinkaDecoder(in_ch*2, in_ch)
 
         self.final = nn.Sequential(
             nn.Conv2d(in_ch, out_ch, 1),
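To make the deepened layout concrete, here is a sketch of the channel/resolution progression, assuming a hypothetical in_ch = 32 and a 256x256 input (the real values come from GinkaModel, which this diff does not show):

    ch, size = 32, 256                       # hypothetical in_ch and input resolution
    for level in range(1, 5):                # down1 .. down4
        ch, size = ch * 2, size // 2         # each GinkaEncoder doubles channels, MaxPool2d(2) halves H/W
        print(f"down{level}: {ch} channels, {size}x{size}")
    # down4 leaves 512 channels at 16x16; the bottleneck stays at in_ch*16,
    # and up1..up4 mirror the path back to 32 channels at 256x256.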
@@ -74,10 +81,14 @@ class GinkaUNet(nn.Module):
     def forward(self, x):
         x_down1, skip1 = self.down1(x)
         x_down2, skip2 = self.down2(x_down1)
+        x_down3, skip3 = self.down3(x_down2)
+        x_down4, skip4 = self.down4(x_down3)
 
-        x = self.bottleneck(x_down2)
+        x = self.bottleneck(x_down4)
 
-        x = self.up1(x, skip2)  # use down2's skip
-        x = self.up2(x, skip1)  # use down1's skip
+        x = self.up1(x, skip4)  # use down4's skip
+        x = self.up2(x, skip3)  # use down3's skip
+        x = self.up3(x, skip2)  # use down2's skip
+        x = self.up4(x, skip1)  # use down1's skip
 
         return self.final(x)
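A quick shape check for the deepened forward pass, as a sketch only: it assumes GinkaUNet(in_ch, out_ch) as the constructor (consistent with the __init__ hunk) and an input whose sides are divisible by 2**4, so the four poolings round-trip cleanly:

    import torch

    net = GinkaUNet(32, 32)            # hypothetical in_ch/out_ch
    x = torch.randn(1, 32, 64, 64)     # H and W must be divisible by 16 (four 2x poolings)
    out = net(x)
    print(out.shape)                   # expected: torch.Size([1, 32, 64, 64])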
@@ -1,46 +0,0 @@
-import torch
-from ..model.model import DynamicPadConv, ConditionInjector, HybridUpsample
-
-def test_dynamic_conv():
-    conv = DynamicPadConv(3, 64, stride=2)
-
-    # Test an odd input size
-    x = torch.randn(1, 3, 15, 17)
-    out = conv(x)
-    assert out.shape == (1, 64, 8, 9), f"Got {out.shape}"
-
-    # Test an even input size
-    x = torch.randn(1, 3, 16, 16)
-    out = conv(x)
-    assert out.shape == (1, 64, 8, 8)
-
-def test_condition_injector():
-    injector = ConditionInjector(128, 256)
-    x = torch.randn(2, 256, 16, 16)
-    cond = torch.randn(2, 128)
-
-    out = injector(x, cond)
-    assert out.shape == x.shape
-    assert not torch.allclose(out, x)  # make sure the condition has an effect
-
-def test_hybrid_upsample():
-    # With a skip connection
-    upsample = HybridUpsample(256, 128, skip_ch=64)
-    x = torch.randn(2, 256, 8, 8)
-    skip = torch.randn(2, 64, 16, 16)
-    out = upsample(x, skip)
-    assert out.shape == (2, 128, 16, 16)
-
-    # Without a skip connection
-    upsample = HybridUpsample(256, 128)
-    out = upsample(x)
-    assert out.shape == (2, 128, 16, 16)
-
-def test_all():
-    test_dynamic_conv()
-    print("✅ dynamic convolution tests passed")
-    test_condition_injector()
-    print("✅ condition injection tests passed")
-    test_hybrid_upsample()
-    print("✅ hybrid upsampling tests passed")
-
@@ -1,4 +0,0 @@
-from .test.model import test_all
-
-if __name__ == "__main__":
-    test_all()
@@ -14,7 +14,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 os.makedirs("result", exist_ok=True)
 os.makedirs("result/ginka_checkpoint", exist_ok=True)
 
-epochs = 70
+epochs = 150
 
 def update_tau(epoch):
     start_tau = 1.0
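The hunk above only shows the first line of update_tau. Purely as an illustration of the usual pattern (not the repo's actual schedule), a Gumbel-softmax temperature is often annealed exponentially from start_tau toward a floor:

    def example_update_tau(epoch, start_tau=1.0, min_tau=0.1, decay=0.97):
        # Hypothetical exponential annealing: tau shrinks each epoch, never below min_tau.
        return max(min_tau, start_tau * decay ** epoch)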
@@ -27,9 +27,13 @@ def train():
     model = GinkaModel()
     model.to(device)
     minamo = MinamoModel(32)
     minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
     minamo.to(device)
+    minamo.eval()
+
+    for param in minamo.parameters():
+        param.requires_grad = False
 
     converter = DynamicGraphConverter().to(device)
 
     # Prepare the dataset
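Why eval() plus requires_grad = False is the right combination here: frozen parameters receive no updates, but gradients still flow through the frozen Minamo back into the generator's output, which is exactly what the similarity loss needs. A standalone sketch, with nn.Linear standing in for MinamoModel:

    import torch
    import torch.nn as nn

    critic = nn.Linear(4, 1)                 # stand-in for the frozen MinamoModel
    for p in critic.parameters():
        p.requires_grad = False

    gen_out = torch.randn(2, 4, requires_grad=True)  # stand-in for the generator output
    critic(gen_out).mean().backward()

    print(gen_out.grad is not None)   # True  -- the generator still gets gradients
    print(critic.weight.grad)         # None  -- the critic stays untouched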
@@ -79,6 +83,18 @@ def train():
         avg_loss = total_loss / len(dataloader)
         tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
+
+        # total_norm = 0
+        # for p in model.parameters():
+        #     if p.grad is not None:
+        #         param_norm = p.grad.detach().data.norm(2)
+        #         total_norm += param_norm.item() ** 2
+        # total_norm = total_norm ** 0.5
+        # tqdm.write(f"Gradient Norm: {total_norm:.4f}")  # should normally stay between 1 and 100
+
+        # for name, param in model.named_parameters():
+        #     if param.grad is not None:
+        #         print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
 
         # Adjust the learning rate
         scheduler.step()
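If the commented-out probe above is ever needed again, PyTorch's clip_grad_norm_ returns the pre-clip total norm, so with a huge max_norm it doubles as a pure measurement. A hypothetical one-liner for the same spot in the loop:

    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)
    tqdm.write(f"Gradient Norm: {total_norm:.4f}")  # per the note above, roughly 1-100 is healthy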
@@ -95,6 +111,7 @@ def train():
 
             # Forward pass
             output, output_softmax = model(feat_vec)
+            print(output_softmax[0])
 
             # Compute the loss
             loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)