mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 11:01:12 +08:00
feat: rnn 随机采用自身生成的内容
This commit is contained in:
parent
d0decfc63a
commit
cfc022724a
@ -1,4 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
|
import random
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -182,7 +183,7 @@ class GinkaRNNModel(nn.Module):
|
|||||||
self.feat_fusion = GinkaInputFusion()
|
self.feat_fusion = GinkaInputFusion()
|
||||||
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
|
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
|
||||||
|
|
||||||
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False):
|
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
|
||||||
"""
|
"""
|
||||||
val_cond: [B, val_dim]
|
val_cond: [B, val_dim]
|
||||||
target_map: [B, H, W]
|
target_map: [B, H, W]
|
||||||
@ -207,6 +208,7 @@ class GinkaRNNModel(nn.Module):
|
|||||||
# 位置编码、图块编码、地图局部编码
|
# 位置编码、图块编码、地图局部编码
|
||||||
tile_embed = self.tile_embedding(now_tile)
|
tile_embed = self.tile_embedding(now_tile)
|
||||||
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
||||||
|
use_self = random.random() < use_self_probility
|
||||||
map_patch = self.map_patch(map if use_self else target_map, x, y)
|
map_patch = self.map_patch(map if use_self else target_map, x, y)
|
||||||
# 编码特征融合
|
# 编码特征融合
|
||||||
feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch)
|
feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch)
|
||||||
|
|||||||
@ -7,19 +7,18 @@ import torch.optim as optim
|
|||||||
import cv2
|
import cv2
|
||||||
from torch_geometric.loader import DataLoader
|
from torch_geometric.loader import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .common.cond import ConditionEncoder
|
|
||||||
from .generator.rnn import GinkaRNNModel
|
from .generator.rnn import GinkaRNNModel
|
||||||
from .dataset import GinkaRNNDataset
|
from .dataset import GinkaRNNDataset
|
||||||
from .generator.loss import RNNGinkaLoss
|
from .generator.loss import RNNGinkaLoss
|
||||||
from shared.image import matrix_to_image_cv
|
from shared.image import matrix_to_image_cv
|
||||||
|
|
||||||
# 手工标注标签定义:
|
# 手工标注标签定义(暂时不用):
|
||||||
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
|
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
|
||||||
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
|
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
|
||||||
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
|
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
|
||||||
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
|
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
|
||||||
|
|
||||||
# 自动标注标签定义:
|
# 自动标注标签定义(暂时不用):
|
||||||
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
|
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
|
||||||
# 32. 平面塔, 33. 转换塔, 34. 道具塔
|
# 32. 平面塔, 33. 转换塔, 34. 道具塔
|
||||||
|
|
||||||
@ -59,6 +58,10 @@ os.makedirs("result/ginka_rnn_img", exist_ok=True)
|
|||||||
|
|
||||||
disable_tqdm = not sys.stdout.isatty()
|
disable_tqdm = not sys.stdout.isatty()
|
||||||
|
|
||||||
|
def gt_prob(epoch: int, max_epoch: int) -> float:
|
||||||
|
progress = epoch / max_epoch
|
||||||
|
return 0.1 + 0.9 * progress
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
parser = argparse.ArgumentParser(description="training codes")
|
parser = argparse.ArgumentParser(description="training codes")
|
||||||
parser.add_argument("--resume", type=bool, default=False)
|
parser.add_argument("--resume", type=bool, default=False)
|
||||||
@ -114,7 +117,7 @@ def train():
|
|||||||
val_cond = batch["val_cond"].to(device)
|
val_cond = batch["val_cond"].to(device)
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs))
|
||||||
|
|
||||||
loss = criterion.rnn_loss(fake_logits, target_map)
|
loss = criterion.rnn_loss(fake_logits, target_map)
|
||||||
|
|
||||||
@ -158,7 +161,7 @@ def train():
|
|||||||
val_cond = batch["val_cond"].to(device)
|
val_cond = batch["val_cond"].to(device)
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epoch))
|
||||||
|
|
||||||
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
|
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user