feat: 降低 UNet 参数量

This commit is contained in:
unanmed 2025-03-22 12:31:01 +08:00
parent eb0626ef88
commit fbfa5f3141
6 changed files with 33 additions and 6 deletions

View File

@ -2,9 +2,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .unet import GinkaUNet from .unet import GinkaUNet
from .sample import MapDownSample
class GinkaModel(nn.Module): class GinkaModel(nn.Module):
def __init__(self, feat_dim=256, base_ch=128, num_classes=32): def __init__(self, feat_dim=256, base_ch=32, num_classes=32):
"""Ginka Model 模型定义部分 """Ginka Model 模型定义部分
""" """
super().__init__() super().__init__()
@ -13,6 +14,7 @@ class GinkaModel(nn.Module):
nn.Linear(feat_dim, 32 * 32 * base_ch) nn.Linear(feat_dim, 32 * 32 * base_ch)
) )
self.unet = GinkaUNet(base_ch, num_classes) self.unet = GinkaUNet(base_ch, num_classes)
self.down_sample = MapDownSample(num_classes, num_classes)
def forward(self, feat): def forward(self, feat):
""" """
@ -24,6 +26,6 @@ class GinkaModel(nn.Module):
x = self.fc(feat) x = self.fc(feat)
x = x.view(-1, self.base_ch, 32, 32) x = x.view(-1, self.base_ch, 32, 32)
x = self.unet(x) x = self.unet(x)
x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False) x = self.down_sample(x)
return F.softmax(x, dim=1) return F.softmax(x, dim=1)

15
ginka/model/sample.py Normal file
View File

@ -0,0 +1,15 @@
import torch
import torch.nn as nn
class MapDownSample(nn.Module):
    """Learned down-sampler that maps a 32x32 feature map to 13x13.

    Replaces the previous ``F.interpolate(x, (13, 13), mode='bilinear')``
    call in ``GinkaModel.forward`` with a learnable alternative, so the
    output spatial size must match the old interpolation target of 13x13.

    Args:
        in_ch: number of input channels (the UNet's class-logit channels).
        out_ch: number of output channels (kept equal to ``in_ch`` by the
            caller so the softmax over classes is unchanged).
    """

    def __init__(self, in_ch=32, out_ch=32):
        super().__init__()
        self.down = nn.Sequential(
            # 32 -> 16: stride-2 3x3 conv, padding 1 halves the spatial size.
            nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1),
            nn.ReLU(),
            # 16 -> 13: valid (padding=0) 4x4 conv gives 16 - 4 + 1 = 13.
            # NOTE(review): the original used kernel_size=3 here, which
            # yields 14x14 and mismatches the 13x13 target it replaced.
            nn.Conv2d(in_ch, out_ch, 4, stride=1, padding=0)
        )

    def forward(self, x):
        """Down-sample ``x`` (N, in_ch, 32, 32) to (N, out_ch, 13, 13)."""
        return self.down(x)

View File

@ -57,7 +57,7 @@ class GinkaBottleneck(nn.Module):
return self.conv(x) return self.conv(x)
class GinkaUNet(nn.Module): class GinkaUNet(nn.Module):
def __init__(self, in_ch=64, out_ch=32): def __init__(self, in_ch=32, out_ch=32):
"""Ginka Model UNet 部分 """Ginka Model UNet 部分
""" """
super().__init__() super().__init__()

View File

@ -48,7 +48,7 @@ def train():
) )
# 设定优化器与调度器 # 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=1e-3) optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss(minamo) criterion = GinkaLoss(minamo)

View File

@ -66,10 +66,15 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
def validate(): def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.") print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = GinkaModel() model = GinkaModel()
state = torch.load("result/ginka_checkpoint/30.pth", map_location=device)["model_state"] state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
model.load_state_dict(state) model.load_state_dict(state)
model.to(device) model.to(device)
for name, param in model.named_parameters():
print(f"Layer: {name}, Params: {param.numel()}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
minamo = MinamoModel(32) minamo = MinamoModel(32)
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
minamo.to(device) minamo.to(device)

View File

@ -13,9 +13,14 @@ def validate():
model = MinamoModel(32) model = MinamoModel(32)
model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
model.to(device) model.to(device)
for name, param in model.named_parameters():
print(f"Layer: {name}, Params: {param.numel()}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
# 准备数据集 # 准备数据集
val_dataset = MinamoDataset("datasets/minamo-eval.json") val_dataset = MinamoDataset("minamo-eval.json")
val_loader = DataLoader( val_loader = DataLoader(
val_dataset, val_dataset,
batch_size=32, batch_size=32,