diff --git a/ginka/model/model.py b/ginka/model/model.py
index dac8576..5f8906e 100644
--- a/ginka/model/model.py
+++ b/ginka/model/model.py
@@ -2,9 +2,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from .unet import GinkaUNet
+from .sample import MapDownSample
 
 class GinkaModel(nn.Module):
-    def __init__(self, feat_dim=256, base_ch=128, num_classes=32):
+    def __init__(self, feat_dim=256, base_ch=32, num_classes=32):
         """Ginka Model 模型定义部分
         """
         super().__init__()
@@ -13,6 +14,7 @@ class GinkaModel(nn.Module):
             nn.Linear(feat_dim, 32 * 32 * base_ch)
         )
         self.unet = GinkaUNet(base_ch, num_classes)
+        self.down_sample = MapDownSample(num_classes, num_classes)
 
     def forward(self, feat):
         """
@@ -24,6 +26,6 @@ class GinkaModel(nn.Module):
 
         x = self.fc(feat)
         x = x.view(-1, self.base_ch, 32, 32)
         x = self.unet(x)
-        x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
+        x = self.down_sample(x)
         return F.softmax(x, dim=1)
\ No newline at end of file
diff --git a/ginka/model/sample.py b/ginka/model/sample.py
new file mode 100644
index 0000000..2af007d
--- /dev/null
+++ b/ginka/model/sample.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+
+class MapDownSample(nn.Module):
+    def __init__(self, in_ch=32, out_ch=32):
+        super().__init__()
+        self.down = nn.Sequential(
+            nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=0),  # 32x32 -> 15x15 (padding=0 so the final map is 13x13, matching the interpolate target this module replaces)
+            nn.ReLU(),
+            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=0)  # 15x15 -> 13x13
+        )
+
+    def forward(self, x):
+        x = self.down(x)
+        return x
diff --git a/ginka/model/unet.py b/ginka/model/unet.py
index 39960c3..aaf8903 100644
--- a/ginka/model/unet.py
+++ b/ginka/model/unet.py
@@ -57,7 +57,7 @@ class GinkaBottleneck(nn.Module):
         return self.conv(x)
 
 class GinkaUNet(nn.Module):
-    def __init__(self, in_ch=64, out_ch=32):
+    def __init__(self, in_ch=32, out_ch=32):
         """Ginka Model UNet 部分
         """
         super().__init__()
diff --git a/ginka/train.py b/ginka/train.py
index 8e71222..4e069bb 100644
--- a/ginka/train.py
+++ b/ginka/train.py
@@ -48,7 +48,7 @@ def train():
     )
     
     # 设定优化器与调度器
-    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
+    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
     scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
     criterion = GinkaLoss(minamo)
     
diff --git a/ginka/validate.py b/ginka/validate.py
index c35e461..8a87bae 100644
--- a/ginka/validate.py
+++ b/ginka/validate.py
@@ -66,10 +66,15 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
 def validate():
     print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
     model = GinkaModel()
-    state = torch.load("result/ginka_checkpoint/30.pth", map_location=device)["model_state"]
+    state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
     model.load_state_dict(state)
     model.to(device)
+    for name, param in model.named_parameters():
+        print(f"Layer: {name}, Params: {param.numel()}")
+    total_params = sum(p.numel() for p in model.parameters())
+    print(f"Total parameters: {total_params}")
+    
     minamo = MinamoModel(32)
     minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
     minamo.to(device)
 
diff --git a/minamo/validate.py b/minamo/validate.py
index 7b70f59..d8e7e3a 100644
--- a/minamo/validate.py
+++ b/minamo/validate.py
@@ -13,9 +13,14 @@ def validate():
     model = MinamoModel(32)
     model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
     model.to(device)
+    
+    for name, param in model.named_parameters():
+        print(f"Layer: {name}, Params: {param.numel()}")
+    total_params = sum(p.numel() for p in model.parameters())
+    print(f"Total parameters: {total_params}")
 
     # 准备数据集
-    val_dataset = MinamoDataset("datasets/minamo-eval.json")
+    val_dataset = MinamoDataset("minamo-eval.json")
     val_loader = DataLoader(
         val_dataset,
         batch_size=32,