mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-20 17:22:41 +08:00
chore: 微调 rnn-vae
This commit is contained in:
parent
c8809b8ee7
commit
e66919d11e
@ -54,10 +54,10 @@ from shared.image import matrix_to_image_cv
|
|||||||
# 29. 入口,不区分楼梯和箭头
|
# 29. 入口,不区分楼梯和箭头
|
||||||
|
|
||||||
BATCH_SIZE = 128
|
BATCH_SIZE = 128
|
||||||
LATENT_DIM = 48
|
LATENT_DIM = 64
|
||||||
KL_BETA = 0.1
|
KL_BETA = 0.01
|
||||||
SELF_GATE = 0.5
|
SELF_GATE = 0.3
|
||||||
GATE_EPOCH = 5
|
GATE_EPOCH = 10
|
||||||
VAL_BATCH_DIVIDER = 128
|
VAL_BATCH_DIVIDER = 128
|
||||||
PROB_STEP = 0.05
|
PROB_STEP = 0.05
|
||||||
|
|
||||||
@ -96,10 +96,10 @@ def train():
|
|||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
||||||
|
|
||||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4)
|
||||||
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
|
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
|
||||||
scheduler_ginka = VAEScheduler(
|
scheduler_ginka = VAEScheduler(
|
||||||
optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=2e-4, min_lr=1e-6
|
optimizer_ginka, factor=0.9, increase_factor=1.5, patience=20, max_lr=3e-4, min_lr=1e-6
|
||||||
)
|
)
|
||||||
|
|
||||||
criterion = VAELoss()
|
criterion = VAELoss()
|
||||||
@ -129,6 +129,7 @@ def train():
|
|||||||
reco_loss_total = torch.Tensor([0]).to(device)
|
reco_loss_total = torch.Tensor([0]).to(device)
|
||||||
kl_loss_total = torch.Tensor([0]).to(device)
|
kl_loss_total = torch.Tensor([0]).to(device)
|
||||||
|
|
||||||
|
vae.train()
|
||||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
@ -182,6 +183,7 @@ def train():
|
|||||||
|
|
||||||
# 每若干轮输出一次图片,并保存检查点
|
# 每若干轮输出一次图片,并保存检查点
|
||||||
if (epoch + 1) % args.checkpoint == 0:
|
if (epoch + 1) % args.checkpoint == 0:
|
||||||
|
vae.eval()
|
||||||
# 保存检查点
|
# 保存检查点
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": vae.state_dict(),
|
"model_state": vae.state_dict(),
|
||||||
|
|||||||
@ -105,7 +105,13 @@ class DecoderInputFusion(nn.Module):
|
|||||||
)
|
)
|
||||||
self.norm = nn.LayerNorm(d_model)
|
self.norm = nn.LayerNorm(d_model)
|
||||||
self.fusion = nn.Sequential(
|
self.fusion = nn.Sequential(
|
||||||
|
nn.Linear(d_model * 2, d_model * 2),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
nn.LayerNorm(d_model * 2),
|
||||||
|
nn.GELU(),
|
||||||
|
|
||||||
nn.Linear(d_model * 2, d_model),
|
nn.Linear(d_model * 2, d_model),
|
||||||
|
nn.Dropout(0.1),
|
||||||
nn.LayerNorm(d_model),
|
nn.LayerNorm(d_model),
|
||||||
nn.GELU()
|
nn.GELU()
|
||||||
)
|
)
|
||||||
@ -138,6 +144,7 @@ class DecoderRNN(nn.Module):
|
|||||||
self.drop = nn.Dropout(0.2)
|
self.drop = nn.Dropout(0.2)
|
||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
nn.Linear(hidden_dim, hidden_dim),
|
nn.Linear(hidden_dim, hidden_dim),
|
||||||
|
nn.Dropout(0.1),
|
||||||
nn.LayerNorm(hidden_dim),
|
nn.LayerNorm(hidden_dim),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
|
|
||||||
@ -168,8 +175,10 @@ class VAEDecoder(nn.Module):
|
|||||||
# 模型结构
|
# 模型结构
|
||||||
self.map_vec_fc = nn.Sequential(
|
self.map_vec_fc = nn.Sequential(
|
||||||
nn.Linear(map_vec_dim, 128),
|
nn.Linear(map_vec_dim, 128),
|
||||||
|
nn.Dropout(0.1),
|
||||||
nn.LayerNorm(128),
|
nn.LayerNorm(128),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
|
|
||||||
nn.Linear(128, 256)
|
nn.Linear(128, 256)
|
||||||
)
|
)
|
||||||
self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes)
|
self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes)
|
||||||
@ -227,22 +236,22 @@ class VAEDecoder(nn.Module):
|
|||||||
return output_logits.permute(0, 3, 1, 2)
|
return output_logits.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
device = torch.device("cpu")
|
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
input = torch.randint(0, 32, [1, 13, 13]).to(device)
|
input = torch.randint(0, 32, [1, 13, 13]).to(device)
|
||||||
map_vec = torch.rand(1, 32).to(device)
|
map_vec = torch.rand(1, 32).to(device)
|
||||||
|
|
||||||
# 初始化模型
|
# 初始化模型
|
||||||
model = VAEDecoder("cpu").to(device)
|
model = VAEDecoder(device).to(device)
|
||||||
|
|
||||||
print_memory("初始化后")
|
print_memory(device, "初始化后")
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
fake_logits = model(map_vec, input, 0)
|
fake_logits = model(map_vec, input, 0)
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
|
|
||||||
print_memory("前向传播后")
|
print_memory(device, "前向传播后")
|
||||||
|
|
||||||
print(f"推理耗时: {end - start}")
|
print(f"推理耗时: {end - start}")
|
||||||
print(f"输出形状: fake_logits={fake_logits.shape}")
|
print(f"输出形状: fake_logits={fake_logits.shape}")
|
||||||
|
|||||||
@ -38,9 +38,10 @@ class EncoderGRU(nn.Module):
|
|||||||
|
|
||||||
# GRU
|
# GRU
|
||||||
self.gru = nn.GRUCell(input_dim, hidden_dim)
|
self.gru = nn.GRUCell(input_dim, hidden_dim)
|
||||||
self.drop = nn.Dropout(0.1)
|
self.drop = nn.Dropout(0.2)
|
||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
nn.Linear(hidden_dim, hidden_dim),
|
nn.Linear(hidden_dim, hidden_dim),
|
||||||
|
nn.Dropout(0.1),
|
||||||
nn.LayerNorm(hidden_dim),
|
nn.LayerNorm(hidden_dim),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
|
|
||||||
@ -62,13 +63,14 @@ class EncoderFusion(nn.Module):
|
|||||||
|
|
||||||
self.transformer = nn.TransformerEncoder(
|
self.transformer = nn.TransformerEncoder(
|
||||||
nn.TransformerEncoderLayer(
|
nn.TransformerEncoderLayer(
|
||||||
d_model=d_model, dim_feedforward=d_model, nhead=2, batch_first=True
|
d_model=d_model, dim_feedforward=d_model*2, nhead=2, batch_first=True
|
||||||
),
|
),
|
||||||
num_layers=2
|
num_layers=3
|
||||||
)
|
)
|
||||||
self.norm = nn.LayerNorm(d_model)
|
self.norm = nn.LayerNorm(d_model)
|
||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
nn.Linear(d_model * 2, d_model * 2),
|
nn.Linear(d_model * 2, d_model * 2),
|
||||||
|
nn.Dropout(0.1),
|
||||||
nn.LayerNorm(d_model * 2),
|
nn.LayerNorm(d_model * 2),
|
||||||
nn.GELU()
|
nn.GELU()
|
||||||
)
|
)
|
||||||
@ -91,11 +93,20 @@ class VAEEncoder(nn.Module):
|
|||||||
self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
|
self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
|
||||||
self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
|
self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
|
||||||
self.fusion = EncoderFusion(256)
|
self.fusion = EncoderFusion(256)
|
||||||
self.fc = nn.Sequential(
|
self.fc_mu = nn.Sequential(
|
||||||
|
nn.Linear(512, 512),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
nn.LayerNorm(512),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(512, latent_dim)
|
||||||
|
)
|
||||||
|
self.fc_logvar = nn.Sequential(
|
||||||
|
nn.Linear(512, 512),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
nn.LayerNorm(512),
|
||||||
|
nn.GELU(),
|
||||||
nn.Linear(512, latent_dim)
|
nn.Linear(512, latent_dim)
|
||||||
)
|
)
|
||||||
self.fc_mu = nn.Linear(512, latent_dim)
|
|
||||||
self.fc_logvar = nn.Linear(512, latent_dim)
|
|
||||||
|
|
||||||
self.col_list = []
|
self.col_list = []
|
||||||
self.row_list = []
|
self.row_list = []
|
||||||
@ -126,21 +137,21 @@ class VAEEncoder(nn.Module):
|
|||||||
return mu, logvar
|
return mu, logvar
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
device = torch.device("cpu")
|
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
input = torch.randint(0, 32, [1, 13, 13]).to(device)
|
input = torch.randint(0, 32, [1, 13, 13]).to(device)
|
||||||
|
|
||||||
# 初始化模型
|
# 初始化模型
|
||||||
model = VAEEncoder(device).to(device)
|
model = VAEEncoder(device).to(device)
|
||||||
|
|
||||||
print_memory("初始化后")
|
print_memory(device, "初始化后")
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
mu, logvar = model(input)
|
mu, logvar = model(input)
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
|
|
||||||
print_memory("前向传播后")
|
print_memory(device, "前向传播后")
|
||||||
|
|
||||||
print(f"推理耗时: {end - start}")
|
print(f"推理耗时: {end - start}")
|
||||||
print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
|
print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user