feat: 降低 UNet 参数量

This commit is contained in:
unanmed 2025-03-22 12:31:01 +08:00
parent eb0626ef88
commit fbfa5f3141
6 changed files with 33 additions and 6 deletions

View File

@ -2,9 +2,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet import GinkaUNet
from .sample import MapDownSample
class GinkaModel(nn.Module):
def __init__(self, feat_dim=256, base_ch=128, num_classes=32):
def __init__(self, feat_dim=256, base_ch=32, num_classes=32):
"""Ginka Model 模型定义部分
"""
super().__init__()
@ -13,6 +14,7 @@ class GinkaModel(nn.Module):
nn.Linear(feat_dim, 32 * 32 * base_ch)
)
self.unet = GinkaUNet(base_ch, num_classes)
self.down_sample = MapDownSample(num_classes, num_classes)
def forward(self, feat):
"""
@ -24,6 +26,6 @@ class GinkaModel(nn.Module):
x = self.fc(feat)
x = x.view(-1, self.base_ch, 32, 32)
x = self.unet(x)
x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
x = self.down_sample(x)
return F.softmax(x, dim=1)

15
ginka/model/sample.py Normal file
View File

@ -0,0 +1,15 @@
import torch
import torch.nn as nn
class MapDownSample(nn.Module):
def __init__(self, in_ch=32, out_ch=32):
super().__init__()
self.down = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=0)
)
def forward(self, x):
x = self.down(x)
return x

View File

@ -57,7 +57,7 @@ class GinkaBottleneck(nn.Module):
return self.conv(x)
class GinkaUNet(nn.Module):
def __init__(self, in_ch=64, out_ch=32):
def __init__(self, in_ch=32, out_ch=32):
"""Ginka Model UNet 部分
"""
super().__init__()

View File

@ -48,7 +48,7 @@ def train():
)
# 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss(minamo)

View File

@ -66,10 +66,15 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = GinkaModel()
state = torch.load("result/ginka_checkpoint/30.pth", map_location=device)["model_state"]
state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
model.load_state_dict(state)
model.to(device)
for name, param in model.named_parameters():
print(f"Layer: {name}, Params: {param.numel()}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
minamo = MinamoModel(32)
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
minamo.to(device)

View File

@ -13,9 +13,14 @@ def validate():
model = MinamoModel(32)
model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
model.to(device)
for name, param in model.named_parameters():
print(f"Layer: {name}, Params: {param.numel()}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
# 准备数据集
val_dataset = MinamoDataset("datasets/minamo-eval.json")
val_dataset = MinamoDataset("minamo-eval.json")
val_loader = DataLoader(
val_dataset,
batch_size=32,