diff --git a/ginka/model/loss.py b/ginka/model/loss.py index ea1e495..37e23ed 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -322,7 +322,7 @@ class GinkaLoss(nn.Module): vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1) topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1) minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim - minamo_loss = torch.exp(-10 * (minamo_sim - 0.8)).mean() + minamo_loss = torch.exp(-1 * (minamo_sim - 0.8)).mean() # print( # minamo_loss.item(), diff --git a/ginka/model/unet.py b/ginka/model/unet.py index ec0f2d3..39960c3 100644 --- a/ginka/model/unet.py +++ b/ginka/model/unet.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from shared.attention import CBAM class GinkaEncoder(nn.Module): """编码器(下采样)部分""" @@ -8,10 +9,11 @@ class GinkaEncoder(nn.Module): self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), - nn.ReLU(), + nn.GELU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), - nn.ReLU() + # CBAM(out_channels), + nn.GELU() ) self.pool = nn.MaxPool2d(2) @@ -29,7 +31,8 @@ class GinkaDecoder(nn.Module): self.conv = nn.Sequential( nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), - nn.ReLU() + # CBAM(out_channels), + nn.GELU() ) def forward(self, x, skip): @@ -44,10 +47,10 @@ class GinkaBottleneck(nn.Module): self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), - nn.ReLU(), + nn.GELU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), - nn.ReLU(), + nn.GELU(), ) def forward(self, x): @@ -60,11 +63,15 @@ class GinkaUNet(nn.Module): super().__init__() self.down1 = GinkaEncoder(in_ch, in_ch*2) self.down2 = GinkaEncoder(in_ch*2, in_ch*4) + self.down3 = GinkaEncoder(in_ch*4, in_ch*8) + self.down4 = 
GinkaEncoder(in_ch*8, in_ch*16) - self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4) + self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16) - self.up1 = GinkaDecoder(in_ch*4, in_ch*2) - self.up2 = GinkaDecoder(in_ch*2, in_ch) + self.up1 = GinkaDecoder(in_ch*16, in_ch*8) + self.up2 = GinkaDecoder(in_ch*8, in_ch*4) + self.up3 = GinkaDecoder(in_ch*4, in_ch*2) + self.up4 = GinkaDecoder(in_ch*2, in_ch) self.final = nn.Sequential( nn.Conv2d(in_ch, out_ch, 1), @@ -74,10 +81,14 @@ class GinkaUNet(nn.Module): def forward(self, x): x_down1, skip1 = self.down1(x) x_down2, skip2 = self.down2(x_down1) + x_down3, skip3 = self.down3(x_down2) + x_down4, skip4 = self.down4(x_down3) - x = self.bottleneck(x_down2) + x = self.bottleneck(x_down4) - x = self.up1(x, skip2) # 用 down2 的 skip - x = self.up2(x, skip1) # 用 down1 的 skip + x = self.up1(x, skip4) # use the skip from down4 + x = self.up2(x, skip3) # use the skip from down3 + x = self.up3(x, skip2) # use the skip from down2 + x = self.up4(x, skip1) # use the skip from down1 return self.final(x) diff --git a/ginka/test/__init__.py b/ginka/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ginka/test/model.py b/ginka/test/model.py deleted file mode 100644 index 05dfa75..0000000 --- a/ginka/test/model.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch -from ..model.model import DynamicPadConv, ConditionInjector, HybridUpsample - -def test_dynamic_conv(): - conv = DynamicPadConv(3, 64, stride=2) - - # 测试奇数尺寸 - x = torch.randn(1, 3, 15, 17) - out = conv(x) - assert out.shape == (1, 64, 8, 9), f"Got {out.shape}" - - # 测试偶数尺寸 - x = torch.randn(1, 3, 16, 16) - out = conv(x) - assert out.shape == (1, 64, 8, 8) - -def test_condition_injector(): - injector = ConditionInjector(128, 256) - x = torch.randn(2, 256, 16, 16) - cond = torch.randn(2, 128) - - out = injector(x, cond) - assert out.shape == x.shape - assert not torch.allclose(out, x) # 确保条件起作用 - -def test_hybrid_upsample(): - # 带跳跃连接的情况 - upsample = HybridUpsample(256, 128, skip_ch=64) - x = 
torch.randn(2, 256, 8, 8) - skip = torch.randn(2, 64, 16, 16) - out = upsample(x, skip) - assert out.shape == (2, 128, 16, 16) - - # 无跳跃连接的情况 - upsample = HybridUpsample(256, 128) - out = upsample(x) - assert out.shape == (2, 128, 16, 16) - -def test_all(): - test_dynamic_conv() - print("✅ 动态卷积测试完毕") - test_condition_injector() - print("✅ 条件注入测试完毕") - test_hybrid_upsample() - print("✅ 混合上采样测试完毕") - \ No newline at end of file diff --git a/ginka/test_model.py b/ginka/test_model.py deleted file mode 100644 index 2843c18..0000000 --- a/ginka/test_model.py +++ /dev/null @@ -1,4 +0,0 @@ -from .test.model import test_all - -if __name__ == "__main__": - test_all() \ No newline at end of file diff --git a/ginka/train.py b/ginka/train.py index eb3559d..17b14a5 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -14,7 +14,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) os.makedirs("result/ginka_checkpoint", exist_ok=True) -epochs = 70 +epochs = 150 def update_tau(epoch): start_tau = 1.0 @@ -27,9 +27,13 @@ def train(): model = GinkaModel() model.to(device) minamo = MinamoModel(32) + minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) minamo.to(device) minamo.eval() + for param in minamo.parameters(): + param.requires_grad = False + converter = DynamicGraphConverter().to(device) # 准备数据集 @@ -79,6 +83,18 @@ def train(): avg_loss = total_loss / len(dataloader) tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") + # total_norm = 0 + # for p in model.parameters(): + # if p.grad is not None: + # param_norm = p.grad.detach().data.norm(2) + # total_norm += param_norm.item() ** 2 + # total_norm = total_norm ** 0.5 + # tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间 + + # for name, param in model.named_parameters(): + # if param.grad is not None: + # print(f"{name}: 
grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}") + # 学习率调整 scheduler.step() @@ -95,6 +111,7 @@ def train(): # 前向传播 output, output_softmax = model(feat_vec) + print(output_softmax[0]) # 计算损失 loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)