mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-15 21:57:52 +08:00
chore: 微调 rnn-vae
This commit is contained in:
parent
c8809b8ee7
commit
e66919d11e
@ -54,10 +54,10 @@ from shared.image import matrix_to_image_cv
|
||||
# 29. 入口,不区分楼梯和箭头
|
||||
|
||||
BATCH_SIZE = 128
|
||||
LATENT_DIM = 48
|
||||
KL_BETA = 0.1
|
||||
SELF_GATE = 0.5
|
||||
GATE_EPOCH = 5
|
||||
LATENT_DIM = 64
|
||||
KL_BETA = 0.01
|
||||
SELF_GATE = 0.3
|
||||
GATE_EPOCH = 10
|
||||
VAL_BATCH_DIVIDER = 128
|
||||
PROB_STEP = 0.05
|
||||
|
||||
@ -96,10 +96,10 @@ def train():
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
||||
|
||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4)
|
||||
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
|
||||
scheduler_ginka = VAEScheduler(
|
||||
optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=2e-4, min_lr=1e-6
|
||||
optimizer_ginka, factor=0.9, increase_factor=1.5, patience=20, max_lr=3e-4, min_lr=1e-6
|
||||
)
|
||||
|
||||
criterion = VAELoss()
|
||||
@ -129,6 +129,7 @@ def train():
|
||||
reco_loss_total = torch.Tensor([0]).to(device)
|
||||
kl_loss_total = torch.Tensor([0]).to(device)
|
||||
|
||||
vae.train()
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
target_map = batch["target_map"].to(device)
|
||||
|
||||
@ -182,6 +183,7 @@ def train():
|
||||
|
||||
# 每若干轮输出一次图片,并保存检查点
|
||||
if (epoch + 1) % args.checkpoint == 0:
|
||||
vae.eval()
|
||||
# 保存检查点
|
||||
torch.save({
|
||||
"model_state": vae.state_dict(),
|
||||
|
||||
@ -105,7 +105,13 @@ class DecoderInputFusion(nn.Module):
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Linear(d_model * 2, d_model * 2),
|
||||
nn.Dropout(0.2),
|
||||
nn.LayerNorm(d_model * 2),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Linear(d_model * 2, d_model),
|
||||
nn.Dropout(0.1),
|
||||
nn.LayerNorm(d_model),
|
||||
nn.GELU()
|
||||
)
|
||||
@ -138,6 +144,7 @@ class DecoderRNN(nn.Module):
|
||||
self.drop = nn.Dropout(0.2)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.Dropout(0.1),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.GELU(),
|
||||
|
||||
@ -168,8 +175,10 @@ class VAEDecoder(nn.Module):
|
||||
# 模型结构
|
||||
self.map_vec_fc = nn.Sequential(
|
||||
nn.Linear(map_vec_dim, 128),
|
||||
nn.Dropout(0.1),
|
||||
nn.LayerNorm(128),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Linear(128, 256)
|
||||
)
|
||||
self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes)
|
||||
@ -227,22 +236,22 @@ class VAEDecoder(nn.Module):
|
||||
return output_logits.permute(0, 3, 1, 2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = torch.device("cpu")
|
||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
input = torch.randint(0, 32, [1, 13, 13]).to(device)
|
||||
map_vec = torch.rand(1, 32).to(device)
|
||||
|
||||
# 初始化模型
|
||||
model = VAEDecoder("cpu").to(device)
|
||||
model = VAEDecoder(device).to(device)
|
||||
|
||||
print_memory("初始化后")
|
||||
print_memory(device, "初始化后")
|
||||
|
||||
# 前向传播
|
||||
start = time.perf_counter()
|
||||
fake_logits = model(map_vec, input, 0)
|
||||
end = time.perf_counter()
|
||||
|
||||
print_memory("前向传播后")
|
||||
print_memory(device, "前向传播后")
|
||||
|
||||
print(f"推理耗时: {end - start}")
|
||||
print(f"输出形状: fake_logits={fake_logits.shape}")
|
||||
|
||||
@ -38,9 +38,10 @@ class EncoderGRU(nn.Module):
|
||||
|
||||
# GRU
|
||||
self.gru = nn.GRUCell(input_dim, hidden_dim)
|
||||
self.drop = nn.Dropout(0.1)
|
||||
self.drop = nn.Dropout(0.2)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.Dropout(0.1),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.GELU(),
|
||||
|
||||
@ -62,13 +63,14 @@ class EncoderFusion(nn.Module):
|
||||
|
||||
self.transformer = nn.TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(
|
||||
d_model=d_model, dim_feedforward=d_model, nhead=2, batch_first=True
|
||||
d_model=d_model, dim_feedforward=d_model*2, nhead=2, batch_first=True
|
||||
),
|
||||
num_layers=2
|
||||
num_layers=3
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(d_model * 2, d_model * 2),
|
||||
nn.Dropout(0.1),
|
||||
nn.LayerNorm(d_model * 2),
|
||||
nn.GELU()
|
||||
)
|
||||
@ -91,11 +93,20 @@ class VAEEncoder(nn.Module):
|
||||
self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
|
||||
self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
|
||||
self.fusion = EncoderFusion(256)
|
||||
self.fc = nn.Sequential(
|
||||
self.fc_mu = nn.Sequential(
|
||||
nn.Linear(512, 512),
|
||||
nn.Dropout(0.1),
|
||||
nn.LayerNorm(512),
|
||||
nn.GELU(),
|
||||
nn.Linear(512, latent_dim)
|
||||
)
|
||||
self.fc_logvar = nn.Sequential(
|
||||
nn.Linear(512, 512),
|
||||
nn.Dropout(0.1),
|
||||
nn.LayerNorm(512),
|
||||
nn.GELU(),
|
||||
nn.Linear(512, latent_dim)
|
||||
)
|
||||
self.fc_mu = nn.Linear(512, latent_dim)
|
||||
self.fc_logvar = nn.Linear(512, latent_dim)
|
||||
|
||||
self.col_list = []
|
||||
self.row_list = []
|
||||
@ -126,21 +137,21 @@ class VAEEncoder(nn.Module):
|
||||
return mu, logvar
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = torch.device("cpu")
|
||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
input = torch.randint(0, 32, [1, 13, 13]).to(device)
|
||||
|
||||
# 初始化模型
|
||||
model = VAEEncoder(device).to(device)
|
||||
|
||||
print_memory("初始化后")
|
||||
print_memory(device, "初始化后")
|
||||
|
||||
# 前向传播
|
||||
start = time.perf_counter()
|
||||
mu, logvar = model(input)
|
||||
end = time.perf_counter()
|
||||
|
||||
print_memory("前向传播后")
|
||||
print_memory(device, "前向传播后")
|
||||
|
||||
print(f"推理耗时: {end - start}")
|
||||
print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user