chore: 微调 rnn-vae

This commit is contained in:
unanmed 2026-03-08 16:05:07 +08:00
parent c8809b8ee7
commit e66919d11e
3 changed files with 41 additions and 19 deletions

View File

@@ -54,10 +54,10 @@ from shared.image import matrix_to_image_cv
# 29. 入口,不区分楼梯和箭头
BATCH_SIZE = 128
LATENT_DIM = 48
KL_BETA = 0.1
SELF_GATE = 0.5
GATE_EPOCH = 5
LATENT_DIM = 64
KL_BETA = 0.01
SELF_GATE = 0.3
GATE_EPOCH = 10
VAL_BATCH_DIVIDER = 128
PROB_STEP = 0.05
@@ -96,10 +96,10 @@ def train():
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4)
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
scheduler_ginka = VAEScheduler(
optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=2e-4, min_lr=1e-6
optimizer_ginka, factor=0.9, increase_factor=1.5, patience=20, max_lr=3e-4, min_lr=1e-6
)
criterion = VAELoss()
@@ -129,6 +129,7 @@ def train():
reco_loss_total = torch.Tensor([0]).to(device)
kl_loss_total = torch.Tensor([0]).to(device)
vae.train()
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
target_map = batch["target_map"].to(device)
@@ -182,6 +183,7 @@ def train():
# 每若干轮输出一次图片,并保存检查点
if (epoch + 1) % args.checkpoint == 0:
vae.eval()
# 保存检查点
torch.save({
"model_state": vae.state_dict(),

View File

@@ -105,7 +105,13 @@ class DecoderInputFusion(nn.Module):
)
self.norm = nn.LayerNorm(d_model)
self.fusion = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.Dropout(0.2),
nn.LayerNorm(d_model * 2),
nn.GELU(),
nn.Linear(d_model * 2, d_model),
nn.Dropout(0.1),
nn.LayerNorm(d_model),
nn.GELU()
)
@@ -138,6 +144,7 @@ class DecoderRNN(nn.Module):
self.drop = nn.Dropout(0.2)
self.fc = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.Dropout(0.1),
nn.LayerNorm(hidden_dim),
nn.GELU(),
@@ -168,8 +175,10 @@ class VAEDecoder(nn.Module):
# 模型结构
self.map_vec_fc = nn.Sequential(
nn.Linear(map_vec_dim, 128),
nn.Dropout(0.1),
nn.LayerNorm(128),
nn.GELU(),
nn.Linear(128, 256)
)
self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes)
@@ -227,22 +236,22 @@ class VAEDecoder(nn.Module):
return output_logits.permute(0, 3, 1, 2)
if __name__ == "__main__":
device = torch.device("cpu")
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
input = torch.randint(0, 32, [1, 13, 13]).to(device)
map_vec = torch.rand(1, 32).to(device)
# 初始化模型
model = VAEDecoder("cpu").to(device)
model = VAEDecoder(device).to(device)
print_memory("初始化后")
print_memory(device, "初始化后")
# 前向传播
start = time.perf_counter()
fake_logits = model(map_vec, input, 0)
end = time.perf_counter()
print_memory("前向传播后")
print_memory(device, "前向传播后")
print(f"推理耗时: {end - start}")
print(f"输出形状: fake_logits={fake_logits.shape}")

View File

@@ -38,9 +38,10 @@ class EncoderGRU(nn.Module):
# GRU
self.gru = nn.GRUCell(input_dim, hidden_dim)
self.drop = nn.Dropout(0.1)
self.drop = nn.Dropout(0.2)
self.fc = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.Dropout(0.1),
nn.LayerNorm(hidden_dim),
nn.GELU(),
@@ -62,13 +63,14 @@ class EncoderFusion(nn.Module):
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=d_model, dim_feedforward=d_model, nhead=2, batch_first=True
d_model=d_model, dim_feedforward=d_model*2, nhead=2, batch_first=True
),
num_layers=2
num_layers=3
)
self.norm = nn.LayerNorm(d_model)
self.fc = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.Dropout(0.1),
nn.LayerNorm(d_model * 2),
nn.GELU()
)
@@ -91,11 +93,20 @@ class VAEEncoder(nn.Module):
self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
self.fusion = EncoderFusion(256)
self.fc = nn.Sequential(
self.fc_mu = nn.Sequential(
nn.Linear(512, 512),
nn.Dropout(0.1),
nn.LayerNorm(512),
nn.GELU(),
nn.Linear(512, latent_dim)
)
self.fc_logvar = nn.Sequential(
nn.Linear(512, 512),
nn.Dropout(0.1),
nn.LayerNorm(512),
nn.GELU(),
nn.Linear(512, latent_dim)
)
self.fc_mu = nn.Linear(512, latent_dim)
self.fc_logvar = nn.Linear(512, latent_dim)
self.col_list = []
self.row_list = []
@@ -126,21 +137,21 @@ class VAEEncoder(nn.Module):
return mu, logvar
if __name__ == "__main__":
device = torch.device("cpu")
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
input = torch.randint(0, 32, [1, 13, 13]).to(device)
# 初始化模型
model = VAEEncoder(device).to(device)
print_memory("初始化后")
print_memory(device, "初始化后")
# 前向传播
start = time.perf_counter()
mu, logvar = model(input)
end = time.perf_counter()
print_memory("前向传播后")
print_memory(device, "前向传播后")
print(f"推理耗时: {end - start}")
print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")