mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-15 13:21:09 +08:00
feat: vae 模型
This commit is contained in:
parent
62ec7daa16
commit
79cf3ab226
225
ginka/train_vae.py
Normal file
225
ginka/train_vae.py
Normal file
@ -0,0 +1,225 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import cv2
|
||||
import numpy as np
|
||||
from torch_geometric.loader import DataLoader
|
||||
from tqdm import tqdm
|
||||
from .vae_rnn.vae import GinkaVAE
|
||||
from .vae_rnn.loss import VAELoss
|
||||
from .dataset import GinkaRNNDataset
|
||||
from shared.image import matrix_to_image_cv
|
||||
|
||||
# 手工标注标签定义(暂时不用):
|
||||
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
|
||||
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
|
||||
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
|
||||
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
|
||||
|
||||
# 自动标注标签定义(暂时不用):
|
||||
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
|
||||
# 32. 平面塔, 33. 转换塔, 34. 道具塔
|
||||
|
||||
# 标量值定义:
|
||||
# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块
|
||||
# 1. 墙体密度,墙壁/地图面积
|
||||
# 2. 装饰密度,装饰数量/地图面积
|
||||
# 3. 门密度,门数量/地图面积
|
||||
# 4. 怪物密度,怪物数量/地图面积
|
||||
# 5. 资源密度,资源数量/地图面积
|
||||
# 6. 宝石密度,宝石数量/地图面积
|
||||
# 7. 血瓶密度,血瓶数量/地图面积
|
||||
# 8. 钥匙密度,钥匙数量/地图面积
|
||||
# 9. 道具密度,道具数量/地图面积
|
||||
# 10. 入口数量
|
||||
# 11. 机关门数量
|
||||
# 12. 咸鱼门数量(多层咸鱼门只算一个)
|
||||
|
||||
# 图块定义:
|
||||
# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地),
|
||||
# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门
|
||||
# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启
|
||||
# 10-12. 三种等级的红宝石
|
||||
# 13-15. 三种等级的蓝宝石
|
||||
# 16-18. 三种等级的绿宝石
|
||||
# 19-22. 四种等级的血瓶
|
||||
# 23-25. 三种等级的道具
|
||||
# 26-28. 三种等级的怪物
|
||||
# 29. 入口,不区分楼梯和箭头
|
||||
|
||||
# Number of maps per training batch.
BATCH_SIZE = 32

# Train on the second GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# Output directories for checkpoints and rendered map images.
for _out_dir in ("result", "result/vae", "result/ginka_vae_img"):
    os.makedirs(_out_dir, exist_ok=True)

# Hide progress bars when stdout is not a terminal (e.g. redirected logs).
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
def gt_prob(epoch: int, max_epoch: int) -> float:
    """Teacher-forcing probability schedule.

    Stays at 0 for roughly the first sixth of training, then rises
    linearly to 1.0 at the final epoch.
    """
    scheduled = 1.2 * (epoch / max_epoch) - 0.2
    return scheduled if scheduled > 0 else 0.0
|
||||
|
||||
def _str2bool(value: str) -> bool:
    """argparse ``type=`` helper that parses boolean-ish strings correctly."""
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def parse_arguments(argv=None):
    """Parse command-line options for VAE training.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``.
            Exposed as a parameter (backward compatible) so the parser is
            unit-testable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="training codes")
    # BUG FIX: ``type=bool`` treats any non-empty string — including "False"
    # — as True, so --resume False actually enabled resuming.  Parse boolean
    # spellings explicitly; ``nargs="?"`` also allows a bare ``--resume``.
    parser.add_argument("--resume", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/vae/ginka-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=_str2bool, nargs="?", const=True, default=True)
    return parser.parse_args(argv)
|
||||
|
||||
def train():
    """Full training loop for the GinkaVAE map generator.

    Trains with scheduled sampling (teacher forcing probability from
    ``gt_prob``), logs average losses every epoch, and every
    ``--checkpoint`` epochs saves the weights and renders validation
    reconstructions, random prior samples, and latent interpolations
    as images under ``result/ginka_vae_img``.
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments()

    vae = GinkaVAE(device).to(device)

    dataset = GinkaRNNDataset(args.train, device)
    dataset_val = GinkaRNNDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)

    optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)

    criterion = VAELoss()

    # Tile sprites used to render generated maps as images.
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    if args.resume:
        data_ginka = torch.load(args.state_ginka, map_location=device)
        vae.load_state_dict(data_ginka["model_state"], strict=False)
        if args.load_optim and data_ginka.get("optim_state") is not None:
            optimizer_ginka.load_state_dict(data_ginka["optim_state"])
        print("Train from loaded state.")

    for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
        vae.train()
        loss_total = 0.0
        reco_loss_total = 0.0
        kl_loss_total = 0.0

        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            target_map = batch["target_map"].to(device)

            fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs))
            loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05)

            # BUG FIX: gradients were never zeroed, so they accumulated
            # across batches and every step used stale gradients.
            optimizer_ginka.zero_grad()
            loss.backward()
            optimizer_ginka.step()

            loss_total += loss.item()
            reco_loss_total += reco_loss.item()
            kl_loss_total += kl_loss.item()

        avg_loss = loss_total / len(dataloader)
        avg_reco_loss = reco_loss_total / len(dataloader)
        avg_kl_loss = kl_loss_total / len(dataloader)
        tqdm.write(
            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco Loss: {avg_reco_loss:.6f} | " +
            f"KL Loss: {avg_kl_loss:.6f} | LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
        )

        scheduler_ginka.step()

        # Every ``--checkpoint`` epochs: save a checkpoint and render images.
        if (epoch + 1) % args.checkpoint == 0:
            # BUG FIX: checkpoints previously went to result/rnn/, which is
            # never created (setup makes result/vae/ and the --state_ginka
            # default points there too).
            torch.save({
                "model_state": vae.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
            }, f"result/vae/ginka-{epoch + 1}.pth")

            vae.eval()
            val_loss_total = 0.0
            reco_loss_total = 0.0
            kl_loss_total = 0.0
            with torch.no_grad():
                idx = 0
                gap = 5
                color = (255, 255, 255)  # white divider
                # NOTE(review): assumes rendered maps are 416 px tall
                # (13 tiles x 32 px) — confirm against matrix_to_image_cv.
                vline = np.full((416, gap, 3), color, dtype=np.uint8)  # vertical divider
                # Reconstruction gallery over the validation set.
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    target_map = batch["target_map"].to(device)

                    fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs))
                    loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05)
                    val_loss_total += loss.item()
                    reco_loss_total += reco_loss.item()
                    kl_loss_total += kl_loss.item()

                    fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    real_map = target_map.cpu().numpy()
                    real_img = matrix_to_image_cv(real_map[0], tile_dict)
                    # BUG FIX: the images were stacked vertically against a
                    # 5-px-wide vertical divider (mismatched widths would
                    # make np.block fail); place them side by side instead.
                    img = np.block([[real_img, vline, fake_img]])
                    cv2.imwrite(f"result/ginka_vae_img/{idx}.png", img)

                    idx += 1

                # Random samples from the prior.
                for i in range(0, 8):
                    z = torch.randn(1, 32).to(device)

                    fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1)
                    fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)

                    cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)

                # Latent interpolation between the first two validation maps.
                map1 = torch.LongTensor(dataset_val.data[0]["map"]).to(device).reshape(1, 13, 13)
                map2 = torch.LongTensor(dataset_val.data[1]["map"]).to(device).reshape(1, 13, 13)
                map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device)
                map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device)
                mu1, logvar1 = vae.encoder(map1_onehot)
                mu2, logvar2 = vae.encoder(map2_onehot)
                z1 = vae.reparameterize(mu1, logvar1)
                z2 = vae.reparameterize(mu2, logvar2)
                real_img1 = matrix_to_image_cv(map1[0], tile_dict)
                real_img2 = matrix_to_image_cv(map2[0], tile_dict)
                for i, t in enumerate(torch.linspace(0, 1, 8)):
                    # BUG FIX: t already spans [0, 1]; the old extra ``/ 8``
                    # squeezed every interpolation point into a tiny
                    # neighbourhood of z1.
                    z = z1 * (1 - t) + z2 * t
                    fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1)
                    fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    img = np.block([[real_img1, vline, fake_img, vline, real_img2]])

                    # Index-based name (the old tensor-valued ``{t}`` name
                    # produced awkward float file names).
                    cv2.imwrite(f"result/ginka_vae_img/{i}_linspace.png", img)

            avg_loss_val = val_loss_total / len(dataloader_val)
            avg_reco_loss = reco_loss_total / len(dataloader_val)
            avg_kl_loss = kl_loss_total / len(dataloader_val)
            tqdm.write(
                f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
                f"Loss: {avg_loss_val:.6f} | Reco Loss: {avg_reco_loss:.6f} | " +
                f"KL Loss: {avg_kl_loss:.6f}"
            )

    print("Train ended.")
    torch.save({
        "model_state": vae.state_dict(),
    }, "result/ginka_rnn.pth")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Cap CPU thread usage before handing control to the training loop.
    torch.set_num_threads(4)
    train()
|
||||
7
ginka/utils.py
Normal file
7
ginka/utils.py
Normal file
@ -0,0 +1,7 @@
|
||||
import torch
|
||||
|
||||
def print_memory(device, tag=""):
    """Print current and peak CUDA memory usage for *device*, labelled with *tag*."""
    if not torch.cuda.is_available():
        print("当前设备不支持 cuda.")
        return
    current_mb = torch.cuda.memory_allocated(device) / 1024**2
    peak_mb = torch.cuda.max_memory_allocated(device) / 1024**2
    print(f"{tag} | 当前显存: {current_mb:.2f} MB, 最大显存: {peak_mb:.2f} MB")
|
||||
236
ginka/vae_rnn/decoder.py
Normal file
236
ginka/vae_rnn/decoder.py
Normal file
@ -0,0 +1,236 @@
|
||||
import time
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..utils import print_memory
|
||||
|
||||
class GinkaMapPatch(nn.Module):
    """Encode the 5x5 already-drawn neighbourhood around the tile at (x, y).

    The map is drawn row by row, left to right, so only tiles above and to
    the left of the current position exist; everything else is masked out.
    The masked patch (one-hot tiles + a validity mask channel) is run
    through a small CNN and projected to a 256-dim feature vector.
    """

    def __init__(self, tile_classes=32, width=13, height=13):
        super().__init__()

        self.width = width
        self.height = height
        # BUG FIX: this was hard-coded to 32, silently ignoring the
        # ``tile_classes`` constructor argument.
        self.tile_classes = tile_classes

        # Local CNN over the (tile_classes + mask) channel patch.
        self.patch_cnn = nn.Sequential(
            nn.Conv2d(tile_classes + 1, 64, 3, padding=1),
            nn.Dropout(0.2),
            nn.BatchNorm2d(64),
            nn.GELU(),

            nn.Conv2d(64, 128, 3),  # 5x5 -> 3x3
            nn.Dropout(0.2),
            nn.BatchNorm2d(128),
            nn.GELU(),

            nn.Flatten()
        )
        self.fc = nn.Linear(128 * 3 * 3, 256)

    def forward(self, map: torch.Tensor, x: int, y: int):
        """
        map: [B, H, W] long tensor of tile ids.
        x, y: current drawing position (column, row).

        Returns: [B, 256] feature of the patch around (x, y).
        """
        B, H, W = map.shape
        mask = torch.zeros([B, 5, 5]).to(map.device)
        result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device)
        # Clamp the window (columns [x-2, x+2], rows [y-4, y]) to map bounds.
        left = x - 2 if x >= 2 else 0
        right = x + 3 if x < self.width - 2 else self.width
        top = y - 4 if y >= 4 else 0
        bottom = y + 1

        # Where the clamped window lands inside the fixed 5x5 patch.
        res_left = left - (x - 2)
        res_right = right - (x + 3) + 5
        res_top = top - (y - 4)
        res_bottom = 5

        result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right]
        # The current tile and the two to its right have not been drawn yet.
        result[:, 4, 2] = 0
        result[:, 4, 3] = 0
        result[:, 4, 4] = 0
        mask[:, res_top:res_bottom, res_left:res_right] = 1
        mask[:, 4, 2] = 0
        mask[:, 4, 3] = 0
        mask[:, 4, 4] = 0
        masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5]).to(map.device)
        # BUG FIX: permute(0, 3, 2, 1) transposed H and W of the one-hot
        # channels while the mask channel stayed untransposed; (0, 3, 1, 2)
        # keeps both in consistent [B, C, H, W] layout.
        masked_result[:, 0:self.tile_classes] = F.one_hot(
            result, num_classes=self.tile_classes
        ).permute(0, 3, 1, 2).float()
        masked_result[:, self.tile_classes] = mask

        feat = self.patch_cnn(masked_result)
        feat = self.fc(feat)
        return feat
|
||||
|
||||
class GinkaTileEmbedding(nn.Module):
    """Embedding lookup for the previously drawn tile id."""

    def __init__(self, tile_classes=32, embed_dim=256):
        super().__init__()
        # One learned vector per tile class.
        self.embedding = nn.Embedding(tile_classes, embed_dim)

    def forward(self, tile: torch.Tensor):
        """tile: long tensor of tile ids -> corresponding embedding vectors."""
        return self.embedding(tile)
|
||||
|
||||
class GinkaPosEmbedding(nn.Module):
    """Learned row/column position embeddings for a fixed-size map."""

    def __init__(self, width=13, height=13, embed_dim=256):
        super().__init__()

        self.width = width
        self.height = height

        # Separate tables for the two axes.
        self.row_embedding = nn.Embedding(width, embed_dim)
        self.col_embedding = nn.Embedding(height, embed_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """
        x, y: [B, 1] long tensors of column / row indices.

        Returns a (row, col) pair of [B, embed_dim] embeddings.
        """
        return self.row_embedding(x).squeeze(1), self.col_embedding(y).squeeze(1)
|
||||
|
||||
class GinkaInputFusion(nn.Module):
    """Fuse the five conditioning vectors with a small Transformer encoder."""

    def __init__(self, d_model=256):
        super().__init__()

        # Self-attention over the five input tokens.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True,
            dropout=0.2
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=4)

    def forward(
        self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
        row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
    ):
        """
        Each input: [B, 256].

        Returns the fused [B, 256] vector — the transformer output at the
        tile-embedding slot.
        """
        tokens = torch.stack([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
        fused = self.transformer(tokens)
        return fused[:, 0]
|
||||
|
||||
class GinkaRNN(nn.Module):
    """Single GRU cell plus an MLP head that scores the next tile."""

    def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512):
        super().__init__()

        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.drop = nn.Dropout(0.2)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),

            nn.Linear(hidden_dim, tile_classes)
        )

    def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor):
        """
        feat_fusion: [B, input_dim] fused conditioning for this step.
        hidden: [B, hidden_dim] previous GRU state.

        Returns (logits [B, tile_classes], new hidden state [B, hidden_dim]).
        """
        new_hidden = self.drop(self.gru(feat_fusion, hidden))
        return self.fc(new_hidden), new_hidden
|
||||
|
||||
class VAEDecoder(nn.Module):
    """Autoregressive map decoder.

    Draws the map tile by tile (row-major) with a GRU, conditioning each
    step on the latent map vector, the previous tile, the position, and a
    local patch of either the ground truth or the model's own canvas
    (scheduled sampling).
    """

    def __init__(self, device: torch.device, start_tile=31, map_vec_dim=32, width=13, height=13):
        super().__init__()

        self.device = device
        self.width = width
        self.height = height
        self.start_tile = start_tile

        self.rnn_hidden = 512
        self.tile_classes = 32

        # Sub-modules.
        self.map_vec_fc = nn.Sequential(
            nn.Linear(map_vec_dim, 256)
        )
        self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes)
        self.pos_embedding = GinkaPosEmbedding()
        self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes)
        self.feat_fusion = GinkaInputFusion()
        self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)

    def forward(self, map_vec: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
        """
        map_vec: [B, vec_dim] latent vector conditioning the whole map.
        target_map: [B, H, W] ground-truth map (teacher-forcing source).
        use_self_probility: probability (per tile) of feeding back the
            model's own prediction instead of the ground truth.

        Returns logits of shape [B, tile_classes, H, W].
        """
        batch = map_vec.shape[0]

        # Running state for the autoregressive sweep.
        prev_tile = torch.LongTensor([self.start_tile]).to(self.device).expand(batch)
        canvas = torch.zeros([batch, self.height, self.width], dtype=torch.int32).to(self.device)
        output_logits = torch.zeros([batch, self.height, self.width, self.tile_classes]).to(self.device)
        hidden: torch.Tensor = torch.zeros(batch, self.rnn_hidden).to(self.device)

        cond_vec = self.map_vec_fc(map_vec)

        for y in range(0, self.height):
            for x in range(0, self.width):
                x_idx = torch.LongTensor([x]).to(self.device).expand(batch, -1)
                y_idx = torch.LongTensor([y]).to(self.device).expand(batch, -1)
                # Position, previous-tile and local-patch encodings.
                tile_embed = self.tile_embedding(prev_tile)
                row_embed, col_embed = self.pos_embedding(x_idx, y_idx)
                # Scheduled sampling: per tile, decide whether the patch and
                # previous tile come from our own canvas or the ground truth.
                use_self = random.random() < use_self_probility
                patch_feat = self.map_patch(canvas if use_self else target_map, x, y)
                # Fuse and step the RNN.
                fused = self.feat_fusion(tile_embed, cond_vec, row_embed, col_embed, patch_feat)
                logits, hidden = self.rnn(fused, hidden)
                # Record this step's output.
                output_logits[:, y, x] = logits[:]
                tile_id = torch.argmax(logits, dim=1).detach()
                canvas[:, y, x] = tile_id[:]
                prev_tile = tile_id if use_self else target_map[:, y, x].detach()

        return output_logits.permute(0, 3, 1, 2)
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: run one forward pass on CPU and report timing / sizes.
    device = torch.device("cpu")

    sample_map = torch.randint(0, 32, [1, 13, 13]).to(device)
    map_vec = torch.rand(1, 32).to(device)

    # Build the model (pass the device object, matching the constructor).
    model = VAEDecoder(device).to(device)

    # BUG FIX: print_memory(device, tag) takes the device first; the tag
    # string was previously passed as the device.
    print_memory(device, "初始化后")

    # Forward pass.  BUG FIX: forward() returns only the logits tensor; the
    # old code tried to unpack a (logits, map) pair and crashed.
    start = time.perf_counter()
    fake_logits = model(map_vec, sample_map, 0)
    end = time.perf_counter()
    fake_map = torch.argmax(fake_logits, dim=1)

    print_memory(device, "前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: fake_logits={fake_logits.shape}, fake_map={fake_map.shape}")
    print(f"Map Vector FC parameters: {sum(p.numel() for p in model.map_vec_fc.parameters())}")
    print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
    print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}")
    print(f"Map Patch parameters: {sum(p.numel() for p in model.map_patch.parameters())}")
    print(f"Feature Fusion parameters: {sum(p.numel() for p in model.feat_fusion.parameters())}")
    print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||
55
ginka/vae_rnn/encoder.py
Normal file
55
ginka/vae_rnn/encoder.py
Normal file
@ -0,0 +1,55 @@
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..utils import print_memory
|
||||
|
||||
class VAEEncoder(nn.Module):
    """Convolutional encoder: one-hot map [B, C, 13, 13] -> (mu, logvar)."""

    def __init__(self, tile_classes=32, latent_dim=32):
        super().__init__()
        # Downsampling CNN trunk.
        self.conv = nn.Sequential(
            nn.Conv2d(tile_classes, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, 3, stride=2, padding=1),   # 13 -> 7
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, 3, stride=2, padding=1),  # 7 -> 4
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Flatten()
        )
        # Separate heads for the Gaussian parameters.
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

    def forward(self, x):
        """
        x: [B, tile_classes, 13, 13] one-hot encoded map.

        Returns (mu, logvar), each of shape [B, latent_dim].
        """
        features = self.conv(x)
        return self.fc_mu(features), self.fc_logvar(features)
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: encode one random map on CPU and report timing / sizes.
    device = torch.device("cpu")

    sample = torch.randint(0, 32, [1, 13, 13]).to(device)
    sample = F.one_hot(sample, 32).permute(0, 3, 1, 2).float()

    # Build the model.
    model = VAEEncoder().to(device)

    # BUG FIX: print_memory(device, tag) takes the device first; the tag
    # string was previously passed as the device.
    print_memory(device, "初始化后")

    # Forward pass.
    start = time.perf_counter()
    mu, logvar = model(sample)
    end = time.perf_counter()

    print_memory(device, "前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
    print(f"CNN parameters: {sum(p.numel() for p in model.conv.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||
17
ginka/vae_rnn/loss.py
Normal file
17
ginka/vae_rnn/loss.py
Normal file
@ -0,0 +1,17 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class VAELoss:
    """Reconstruction + KL divergence loss for the map VAE."""

    def __init__(self):
        # Number of tile classes in the generated maps.
        self.num_classes = 32

    def vae_loss(self, logits, target, mu, logvar, beta=0.1):
        """
        logits: [B, C, H, W] unnormalized per-tile class scores.
        target: [B, H, W] long tensor of ground-truth tile ids.
        mu, logvar: latent Gaussian parameters, each [B, latent_dim].
        beta: weight of the KL term.

        Returns (total, reconstruction, kl) losses as scalar tensors.
        """
        # cross_entropy accepts class indices directly; converting the hard
        # labels to one-hot probabilities first was redundant (identical
        # result, extra memory and work).
        recon_loss = F.cross_entropy(logits, target)

        # KL(q(z|x) || N(0, I)), averaged over batch and latent dimensions.
        kl_loss = -0.5 * torch.mean(
            1 + logvar - mu.pow(2) - logvar.exp()
        )

        return recon_loss + beta * kl_loss, recon_loss, kl_loss
|
||||
23
ginka/vae_rnn/vae.py
Normal file
23
ginka/vae_rnn/vae.py
Normal file
@ -0,0 +1,23 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .encoder import VAEEncoder
|
||||
from .decoder import VAEDecoder
|
||||
|
||||
class GinkaVAE(nn.Module):
    """Map VAE: convolutional encoder + autoregressive RNN decoder."""

    def __init__(self, device, tile_classes=32, latent_dim=32):
        super().__init__()
        self.encoder = VAEEncoder(tile_classes, latent_dim)
        self.decoder = VAEDecoder(device)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def forward(self, target_map: torch.Tensor, use_self_probility=0):
        """
        target_map: [B, H, W] long tensor of tile ids.

        Returns (decoder logits [B, 32, H, W], mu, logvar).
        """
        one_hot = F.one_hot(target_map, num_classes=32).float().permute(0, 3, 1, 2)
        mu, logvar = self.encoder(one_hot)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z, target_map, use_self_probility), mu, logvar
|
||||
Loading…
Reference in New Issue
Block a user