mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: rnn 随机采用自身生成的内容
This commit is contained in:
parent
d0decfc63a
commit
cfc022724a
@ -1,4 +1,5 @@
|
||||
import time
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -182,7 +183,7 @@ class GinkaRNNModel(nn.Module):
|
||||
self.feat_fusion = GinkaInputFusion()
|
||||
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
|
||||
|
||||
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False):
|
||||
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
|
||||
"""
|
||||
val_cond: [B, val_dim]
|
||||
target_map: [B, H, W]
|
||||
@ -207,6 +208,7 @@ class GinkaRNNModel(nn.Module):
|
||||
# 位置编码、图块编码、地图局部编码
|
||||
tile_embed = self.tile_embedding(now_tile)
|
||||
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
||||
use_self = random.random() < use_self_probility
|
||||
map_patch = self.map_patch(map if use_self else target_map, x, y)
|
||||
# 编码特征融合
|
||||
feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch)
|
||||
|
||||
@ -7,19 +7,18 @@ import torch.optim as optim
|
||||
import cv2
|
||||
from torch_geometric.loader import DataLoader
|
||||
from tqdm import tqdm
|
||||
from .common.cond import ConditionEncoder
|
||||
from .generator.rnn import GinkaRNNModel
|
||||
from .dataset import GinkaRNNDataset
|
||||
from .generator.loss import RNNGinkaLoss
|
||||
from shared.image import matrix_to_image_cv
|
||||
|
||||
# 手工标注标签定义:
|
||||
# 手工标注标签定义(暂时不用):
|
||||
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
|
||||
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
|
||||
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
|
||||
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
|
||||
|
||||
# 自动标注标签定义:
|
||||
# 自动标注标签定义(暂时不用):
|
||||
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
|
||||
# 32. 平面塔, 33. 转换塔, 34. 道具塔
|
||||
|
||||
@ -59,6 +58,10 @@ os.makedirs("result/ginka_rnn_img", exist_ok=True)
|
||||
|
||||
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
def gt_prob(epoch: int, max_epoch: int) -> float:
|
||||
progress = epoch / max_epoch
|
||||
return 0.1 + 0.9 * progress
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="training codes")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
@ -114,7 +117,7 @@ def train():
|
||||
val_cond = batch["val_cond"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs))
|
||||
|
||||
loss = criterion.rnn_loss(fake_logits, target_map)
|
||||
|
||||
@ -158,7 +161,7 @@ def train():
|
||||
val_cond = batch["val_cond"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epoch))
|
||||
|
||||
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user