mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 降低 UNet 参数量
This commit is contained in:
parent
eb0626ef88
commit
fbfa5f3141
@ -2,9 +2,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .unet import GinkaUNet
|
||||
from .sample import MapDownSample
|
||||
|
||||
class GinkaModel(nn.Module):
|
||||
def __init__(self, feat_dim=256, base_ch=128, num_classes=32):
|
||||
def __init__(self, feat_dim=256, base_ch=32, num_classes=32):
|
||||
"""Ginka Model 模型定义部分
|
||||
"""
|
||||
super().__init__()
|
||||
@ -13,6 +14,7 @@ class GinkaModel(nn.Module):
|
||||
nn.Linear(feat_dim, 32 * 32 * base_ch)
|
||||
)
|
||||
self.unet = GinkaUNet(base_ch, num_classes)
|
||||
self.down_sample = MapDownSample(num_classes, num_classes)
|
||||
|
||||
def forward(self, feat):
|
||||
"""
|
||||
@ -24,6 +26,6 @@ class GinkaModel(nn.Module):
|
||||
x = self.fc(feat)
|
||||
x = x.view(-1, self.base_ch, 32, 32)
|
||||
x = self.unet(x)
|
||||
x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
|
||||
x = self.down_sample(x)
|
||||
return F.softmax(x, dim=1)
|
||||
|
||||
15
ginka/model/sample.py
Normal file
15
ginka/model/sample.py
Normal file
@ -0,0 +1,15 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class MapDownSample(nn.Module):
    """Two-stage convolutional downsampler for map feature grids.

    Stage 1 is a stride-2 3x3 convolution (roughly halves H and W);
    stage 2 is a stride-1, padding-0 3x3 convolution (trims one pixel
    from each border). The submodule is kept under the attribute name
    ``down`` so existing checkpoints (keys ``down.0.*`` / ``down.2.*``)
    still load.
    """

    def __init__(self, in_ch=32, out_ch=32):
        super().__init__()
        # Layer order must stay exactly as below — it fixes both the
        # state_dict key numbering and the forward computation.
        stages = [
            nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=0),
        ]
        self.down = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both downsampling stages and return the reduced map."""
        return self.down(x)
|
||||
@ -57,7 +57,7 @@ class GinkaBottleneck(nn.Module):
|
||||
return self.conv(x)
|
||||
|
||||
class GinkaUNet(nn.Module):
|
||||
def __init__(self, in_ch=64, out_ch=32):
|
||||
def __init__(self, in_ch=32, out_ch=32):
|
||||
"""Ginka Model UNet 部分
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -48,7 +48,7 @@ def train():
|
||||
)
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion = GinkaLoss(minamo)
|
||||
|
||||
|
||||
@ -66,10 +66,15 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||
def validate():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||
model = GinkaModel()
|
||||
state = torch.load("result/ginka_checkpoint/30.pth", map_location=device)["model_state"]
|
||||
state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
|
||||
model.load_state_dict(state)
|
||||
model.to(device)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
print(f"Layer: {name}, Params: {param.numel()}")
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
print(f"Total parameters: {total_params}")
|
||||
|
||||
minamo = MinamoModel(32)
|
||||
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
||||
minamo.to(device)
|
||||
|
||||
@ -13,9 +13,14 @@ def validate():
|
||||
model = MinamoModel(32)
|
||||
model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
||||
model.to(device)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
print(f"Layer: {name}, Params: {param.numel()}")
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
print(f"Total parameters: {total_params}")
|
||||
|
||||
# 准备数据集
|
||||
val_dataset = MinamoDataset("datasets/minamo-eval.json")
|
||||
val_dataset = MinamoDataset("minamo-eval.json")
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=32,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user