feat: rnn 随机采用自身生成的内容

This commit is contained in:
unanmed 2026-01-16 23:01:37 +08:00
parent d0decfc63a
commit cfc022724a
2 changed files with 11 additions and 6 deletions

View File

@ -1,4 +1,5 @@
import time import time
import random
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -182,7 +183,7 @@ class GinkaRNNModel(nn.Module):
self.feat_fusion = GinkaInputFusion() self.feat_fusion = GinkaInputFusion()
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden) self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False): def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
""" """
val_cond: [B, val_dim] val_cond: [B, val_dim]
target_map: [B, H, W] target_map: [B, H, W]
@ -207,6 +208,7 @@ class GinkaRNNModel(nn.Module):
# 位置编码、图块编码、地图局部编码 # 位置编码、图块编码、地图局部编码
tile_embed = self.tile_embedding(now_tile) tile_embed = self.tile_embedding(now_tile)
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor) row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
use_self = random.random() < use_self_probility
map_patch = self.map_patch(map if use_self else target_map, x, y) map_patch = self.map_patch(map if use_self else target_map, x, y)
# 编码特征融合 # 编码特征融合
feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch) feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch)

View File

@ -7,19 +7,18 @@ import torch.optim as optim
import cv2 import cv2
from torch_geometric.loader import DataLoader from torch_geometric.loader import DataLoader
from tqdm import tqdm from tqdm import tqdm
from .common.cond import ConditionEncoder
from .generator.rnn import GinkaRNNModel from .generator.rnn import GinkaRNNModel
from .dataset import GinkaRNNDataset from .dataset import GinkaRNNDataset
from .generator.loss import RNNGinkaLoss from .generator.loss import RNNGinkaLoss
from shared.image import matrix_to_image_cv from shared.image import matrix_to_image_cv
# 手工标注标签定义 # 手工标注标签定义(暂时不用)
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, # 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风 # 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门 # 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
# 24. 中心对称, 25. 部分对称, 26. 鱼骨 # 24. 中心对称, 25. 部分对称, 26. 鱼骨
# 自动标注标签定义 # 自动标注标签定义(暂时不用)
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊 # 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
# 32. 平面塔, 33. 转换塔, 34. 道具塔 # 32. 平面塔, 33. 转换塔, 34. 道具塔
@ -59,6 +58,10 @@ os.makedirs("result/ginka_rnn_img", exist_ok=True)
disable_tqdm = not sys.stdout.isatty() disable_tqdm = not sys.stdout.isatty()
def gt_prob(epoch: int, max_epoch: int) -> float:
progress = epoch / max_epoch
return 0.1 + 0.9 * progress
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description="training codes") parser = argparse.ArgumentParser(description="training codes")
parser.add_argument("--resume", type=bool, default=False) parser.add_argument("--resume", type=bool, default=False)
@ -114,7 +117,7 @@ def train():
val_cond = batch["val_cond"].to(device) val_cond = batch["val_cond"].to(device)
target_map = batch["target_map"].to(device) target_map = batch["target_map"].to(device)
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs))
loss = criterion.rnn_loss(fake_logits, target_map) loss = criterion.rnn_loss(fake_logits, target_map)
@ -158,7 +161,7 @@ def train():
val_cond = batch["val_cond"].to(device) val_cond = batch["val_cond"].to(device)
target_map = batch["target_map"].to(device) target_map = batch["target_map"].to(device)
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epoch))
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach() val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()